Merge pull request #90 from ariel-frischer/fallback-tools
Experimental Tool Fallback Handler
commit 9b0027a922
@ -175,6 +175,8 @@ More information is available in our [Usage Examples](https://docs.ra-aid.ai/cat

- `--hil, -H`: Enable human-in-the-loop mode for interactive assistance during task execution
- `--chat`: Enable chat mode with direct human interaction (implies --hil)
- `--verbose`: Enable verbose logging output
- `--experimental-fallback-handler`: Enable the experimental fallback handler, which attempts to fix tool calls when they fail 3 times consecutively.
- `--pretty-logger`: Enable panel-based, Markdown-formatted logger messages for debugging purposes.
- `--temperature`: LLM temperature (0.0-2.0) to control randomness in responses
- `--disable-limit-tokens`: Disable token limiting for Anthropic Claude react agents
- `--recursion-limit`: Maximum recursion depth for agent operations (default: 100)
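For example, the new handler can be enabled alongside chat mode with something like `ra-aid --chat --experimental-fallback-handler` (using the `ra-aid` entry point and flag names documented above); when a tool call fails three times in a row, the fallback handler retries the call against alternative models instead of letting the agent keep looping on the same error.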
@ -12,7 +12,6 @@ from rich.text import Text
|
|||
from ra_aid import print_error, print_stage_header
|
||||
from ra_aid.__version__ import __version__
|
||||
from ra_aid.agent_utils import (
|
||||
AgentInterrupt,
|
||||
create_agent,
|
||||
run_agent_with_retry,
|
||||
run_planning_agent,
|
||||
|
|
@ -22,11 +21,15 @@ from ra_aid.config import (
|
|||
DEFAULT_MAX_TEST_CMD_RETRIES,
|
||||
DEFAULT_RECURSION_LIMIT,
|
||||
DEFAULT_TEST_CMD_TIMEOUT,
|
||||
VALID_PROVIDERS,
|
||||
)
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.dependencies import check_dependencies
|
||||
from ra_aid.env import validate_environment
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.logging_config import get_logger, setup_logging
|
||||
from ra_aid.models_params import DEFAULT_TEMPERATURE, models_params
|
||||
from ra_aid.project_info import format_project_info, get_project_info
|
||||
from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||
from ra_aid.tool_configs import get_chat_tools
|
||||
|
|
@ -45,14 +48,6 @@ def launch_webui(host: str, port: int):
|
|||
|
||||
|
||||
def parse_arguments(args=None):
|
||||
VALID_PROVIDERS = [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"openrouter",
|
||||
"openai-compatible",
|
||||
"deepseek",
|
||||
"gemini",
|
||||
]
|
||||
ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
OPENAI_DEFAULT_MODEL = "gpt-4o"
|
||||
|
||||
|
|
@ -145,6 +140,9 @@ Examples:
|
|||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Enable verbose logging output"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretty-logger", action="store_true", help="Enable pretty logging output"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
|
|
@ -156,6 +154,11 @@ Examples:
|
|||
action="store_false",
|
||||
help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--experimental-fallback-handler",
|
||||
action="store_true",
|
||||
help="Enable experimental fallback handler.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursion-limit",
|
||||
type=int,
|
||||
|
|
@ -286,7 +289,7 @@ def is_stage_requested(stage: str) -> bool:
|
|||
def main():
|
||||
"""Main entry point for the ra-aid command line tool."""
|
||||
args = parse_arguments()
|
||||
setup_logging(args.verbose)
|
||||
setup_logging(args.verbose, args.pretty_logger)
|
||||
logger.debug("Starting RA.Aid with arguments: %s", args)
|
||||
|
||||
# Launch web interface if requested
|
||||
|
|
@ -304,7 +307,6 @@ def main():
|
|||
logger.debug("Environment validation successful")
|
||||
|
||||
# Validate model configuration early
|
||||
from ra_aid.models_params import models_params
|
||||
|
||||
model_config = models_params.get(args.provider, {}).get(args.model or "", {})
|
||||
supports_temperature = model_config.get(
|
||||
|
|
@ -316,10 +318,10 @@ def main():
|
|||
if supports_temperature and args.temperature is None:
|
||||
args.temperature = model_config.get("default_temperature")
|
||||
if args.temperature is None:
|
||||
print_error(
|
||||
f"Temperature must be provided for model {args.model} which supports temperature"
|
||||
cpm(
|
||||
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
|
||||
)
|
||||
sys.exit(1)
|
||||
args.temperature = DEFAULT_TEMPERATURE
|
||||
logger.debug(
|
||||
f"Using default temperature {args.temperature} for model {args.model}"
|
||||
)
|
||||
|
|
@ -445,6 +447,7 @@ def main():
|
|||
"auto_test": args.auto_test,
|
||||
"test_cmd": args.test_cmd,
|
||||
"max_test_cmd_retries": args.max_test_cmd_retries,
|
||||
"experimental_fallback_handler": args.experimental_fallback_handler,
|
||||
"test_cmd_timeout": args.test_cmd_timeout,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,12 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
|
|||
import litellm
|
||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, trim_messages
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
trim_messages,
|
||||
)
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
|
@ -23,10 +28,16 @@ from rich.markdown import Markdown
|
|||
from rich.panel import Panel
|
||||
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
from ra_aid.agents_alias import RAgents
|
||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
||||
from ra_aid.console.formatting import print_error, print_stage_header
|
||||
from ra_aid.console.output import print_agent_output
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
from ra_aid.exceptions import (
|
||||
AgentInterrupt,
|
||||
FallbackToolExecutionError,
|
||||
ToolExecutionError,
|
||||
)
|
||||
from ra_aid.fallback_handler import FallbackHandler
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||
from ra_aid.project_info import (
|
||||
|
|
@ -238,7 +249,7 @@ def create_agent(
|
|||
*,
|
||||
checkpointer: Any = None,
|
||||
agent_type: str = "default",
|
||||
) -> Any:
|
||||
):
|
||||
"""Create a react agent with the given configuration.
|
||||
|
||||
Args:
|
||||
|
|
@ -270,7 +281,7 @@ def create_agent(
|
|||
return create_react_agent(model, tools, **agent_kwargs)
|
||||
else:
|
||||
logger.debug("Using CiaynAgent agent instance")
|
||||
return CiaynAgent(model, tools, max_tokens=max_input_tokens)
|
||||
return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
|
||||
|
||||
except Exception as e:
|
||||
# Default to REACT agent if provider/model detection fails
|
||||
|
|
@ -332,9 +343,6 @@ def run_research_agent(
|
|||
if memory is None:
|
||||
memory = MemorySaver()
|
||||
|
||||
if thread_id is None:
|
||||
thread_id = str(uuid.uuid4())
|
||||
|
||||
tools = get_research_tools(
|
||||
research_only=research_only,
|
||||
expert_enabled=expert_enabled,
|
||||
|
|
@ -405,8 +413,11 @@ def run_research_agent(
|
|||
display_project_status(project_info)
|
||||
|
||||
if agent is not None:
|
||||
logger.debug("Research agent completed successfully")
|
||||
_result = run_agent_with_retry(agent, prompt, run_config)
|
||||
logger.debug("Research agent created successfully")
|
||||
none_or_fallback_handler = init_fallback_handler(agent, config, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, prompt, run_config, none_or_fallback_handler
|
||||
)
|
||||
if _result:
|
||||
# Log research completion
|
||||
log_work_event(f"Completed research phase for: {base_task_or_query}")
|
||||
|
|
@ -522,7 +533,10 @@ def run_web_research_agent(
|
|||
console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
|
||||
|
||||
logger.debug("Web research agent completed successfully")
|
||||
_result = run_agent_with_retry(agent, prompt, run_config)
|
||||
none_or_fallback_handler = init_fallback_handler(agent, config, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, prompt, run_config, none_or_fallback_handler
|
||||
)
|
||||
if _result:
|
||||
# Log web research completion
|
||||
log_work_event(f"Completed web research phase for: {query}")
|
||||
|
|
@ -627,7 +641,10 @@ def run_planning_agent(
|
|||
try:
|
||||
print_stage_header("Planning Stage")
|
||||
logger.debug("Planning agent completed successfully")
|
||||
_result = run_agent_with_retry(agent, planning_prompt, run_config)
|
||||
none_or_fallback_handler = init_fallback_handler(agent, config, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, planning_prompt, run_config, none_or_fallback_handler
|
||||
)
|
||||
if _result:
|
||||
# Log planning completion
|
||||
log_work_event(f"Completed planning phase for: {base_task}")
|
||||
|
|
@ -732,7 +749,10 @@ def run_task_implementation_agent(
|
|||
|
||||
try:
|
||||
logger.debug("Implementation agent completed successfully")
|
||||
_result = run_agent_with_retry(agent, prompt, run_config)
|
||||
none_or_fallback_handler = init_fallback_handler(agent, config, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, prompt, run_config, none_or_fallback_handler
|
||||
)
|
||||
if _result:
|
||||
# Log task implementation completion
|
||||
log_work_event(f"Completed implementation of task: {task}")
|
||||
|
|
@ -775,49 +795,141 @@ def check_interrupt():
|
|||
raise AgentInterrupt("Interrupt requested")
|
||||
|
||||
|
||||
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
||||
"""Run an agent with retry logic for API errors."""
|
||||
logger.debug("Running agent with prompt length: %d", len(prompt))
|
||||
original_handler = None
|
||||
# New helper functions for run_agent_with_retry refactoring
|
||||
def _setup_interrupt_handling():
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
original_handler = signal.getsignal(signal.SIGINT)
|
||||
signal.signal(signal.SIGINT, _request_interrupt)
|
||||
return original_handler
|
||||
return None
|
||||
|
||||
|
||||
def _restore_interrupt_handling(original_handler):
|
||||
if original_handler and threading.current_thread() is threading.main_thread():
|
||||
signal.signal(signal.SIGINT, original_handler)
|
||||
|
||||
|
||||
def _increment_agent_depth():
|
||||
current_depth = _global_memory.get("agent_depth", 0)
|
||||
_global_memory["agent_depth"] = current_depth + 1
|
||||
|
||||
|
||||
def _decrement_agent_depth():
|
||||
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
||||
|
||||
|
||||
def reset_agent_completion_flags():
|
||||
_global_memory["plan_completed"] = False
|
||||
_global_memory["task_completed"] = False
|
||||
_global_memory["completion_message"] = ""
|
||||
|
||||
|
||||
def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
|
||||
return execute_test_command(config, original_prompt, test_attempts, auto_test)
|
||||
|
||||
|
||||
def _handle_api_error(e, attempt, max_retries, base_delay):
|
||||
if isinstance(e, ValueError):
|
||||
error_str = str(e).lower()
|
||||
if "code" not in error_str or "429" not in error_str:
|
||||
raise e
|
||||
if attempt == max_retries - 1:
|
||||
logger.error("Max retries reached, failing: %s", str(e))
|
||||
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
|
||||
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
|
||||
delay = base_delay * (2**attempt)
|
||||
print_error(
|
||||
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||
)
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < delay:
|
||||
check_interrupt()
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]:
|
||||
"""
|
||||
Determines the type of the agent.
|
||||
Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React".
|
||||
"""
|
||||
|
||||
if isinstance(agent, CiaynAgent):
|
||||
return "CiaynAgent"
|
||||
else:
|
||||
return "React"
|
||||
|
||||
|
||||
def init_fallback_handler(agent: RAgents, config: Dict[str, Any], tools: List[Any]):
|
||||
"""
|
||||
Initialize fallback handler if agent is of type "React" and experimental_fallback_handler is enabled; otherwise return None.
|
||||
"""
|
||||
if not config.get("experimental_fallback_handler", False):
|
||||
return None
|
||||
agent_type = get_agent_type(agent)
|
||||
if agent_type == "React":
|
||||
return FallbackHandler(config, tools)
|
||||
return None
|
||||
|
||||
|
||||
def _handle_fallback_response(
|
||||
error: ToolExecutionError,
|
||||
fallback_handler: Optional[FallbackHandler],
|
||||
agent: RAgents,
|
||||
msg_list: list,
|
||||
) -> None:
|
||||
"""
|
||||
Handle fallback response by invoking fallback_handler and updating msg_list.
|
||||
"""
|
||||
if not fallback_handler:
|
||||
return
|
||||
fallback_response = fallback_handler.handle_failure(error, agent, msg_list)
|
||||
agent_type = get_agent_type(agent)
|
||||
if fallback_response and agent_type == "React":
|
||||
msg_list_response = [HumanMessage(str(msg)) for msg in fallback_response]
|
||||
msg_list.extend(msg_list_response)
|
||||
|
||||
|
||||
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
|
||||
for chunk in agent.stream({"messages": msg_list}, config):
|
||||
logger.debug("Agent output: %s", chunk)
|
||||
check_interrupt()
|
||||
agent_type = get_agent_type(agent)
|
||||
print_agent_output(chunk, agent_type)
|
||||
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
||||
reset_agent_completion_flags()
|
||||
break
|
||||
|
||||
|
||||
def run_agent_with_retry(
|
||||
agent: RAgents,
|
||||
prompt: str,
|
||||
config: dict,
|
||||
fallback_handler: Optional[FallbackHandler] = None,
|
||||
) -> Optional[str]:
|
||||
"""Run an agent with retry logic for API errors."""
|
||||
logger.debug("Running agent with prompt length: %d", len(prompt))
|
||||
original_handler = _setup_interrupt_handling()
|
||||
max_retries = 20
|
||||
base_delay = 1
|
||||
test_attempts = 0
|
||||
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
|
||||
auto_test = config.get("auto_test", False)
|
||||
original_prompt = prompt
|
||||
msg_list = [HumanMessage(content=prompt)]
|
||||
|
||||
with InterruptibleSection():
|
||||
try:
|
||||
# Track agent execution depth
|
||||
current_depth = _global_memory.get("agent_depth", 0)
|
||||
_global_memory["agent_depth"] = current_depth + 1
|
||||
|
||||
_increment_agent_depth()
|
||||
for attempt in range(max_retries):
|
||||
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
||||
check_interrupt()
|
||||
try:
|
||||
for chunk in agent.stream(
|
||||
{"messages": [HumanMessage(content=prompt)]}, config
|
||||
):
|
||||
logger.debug("Agent output: %s", chunk)
|
||||
check_interrupt()
|
||||
print_agent_output(chunk)
|
||||
if _global_memory["plan_completed"]:
|
||||
_global_memory["plan_completed"] = False
|
||||
_global_memory["task_completed"] = False
|
||||
break
|
||||
if _global_memory["task_completed"]:
|
||||
_global_memory["task_completed"] = False
|
||||
break
|
||||
|
||||
# Execute test command if configured
|
||||
_run_agent_stream(agent, msg_list, config)
|
||||
if fallback_handler:
|
||||
fallback_handler.reset_fallback_handler()
|
||||
should_break, prompt, auto_test, test_attempts = (
|
||||
execute_test_command(
|
||||
config, original_prompt, test_attempts, auto_test
|
||||
_execute_test_command_wrapper(
|
||||
original_prompt, config, test_attempts, auto_test
|
||||
)
|
||||
)
|
||||
if should_break:
|
||||
|
|
@ -827,6 +939,13 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
|||
|
||||
logger.debug("Agent run completed successfully")
|
||||
return "Agent run completed successfully"
|
||||
except ToolExecutionError as e:
|
||||
_handle_fallback_response(e, fallback_handler, agent, msg_list)
|
||||
continue
|
||||
except FallbackToolExecutionError as e:
|
||||
msg_list.append(
|
||||
SystemMessage(f"FallbackToolExecutionError:{str(e)}")
|
||||
)
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
raise
|
||||
except (
|
||||
|
|
@ -836,35 +955,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
|||
APIError,
|
||||
ValueError,
|
||||
) as e:
|
||||
if isinstance(e, ValueError):
|
||||
error_str = str(e).lower()
|
||||
if "code" not in error_str or "429" not in error_str:
|
||||
raise # Re-raise ValueError if it's not a Lambda 429
|
||||
if attempt == max_retries - 1:
|
||||
logger.error("Max retries reached, failing: %s", str(e))
|
||||
raise RuntimeError(
|
||||
f"Max retries ({max_retries}) exceeded. Last error: {e}"
|
||||
)
|
||||
logger.warning(
|
||||
"API error (attempt %d/%d): %s",
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
str(e),
|
||||
)
|
||||
delay = base_delay * (2**attempt)
|
||||
print_error(
|
||||
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||
)
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < delay:
|
||||
check_interrupt()
|
||||
time.sleep(0.1)
|
||||
_handle_api_error(e, attempt, max_retries, base_delay)
|
||||
finally:
|
||||
# Reset depth tracking
|
||||
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
||||
|
||||
if (
|
||||
original_handler
|
||||
and threading.current_thread() is threading.main_thread()
|
||||
):
|
||||
signal.signal(signal.SIGINT, original_handler)
|
||||
_decrement_agent_depth()
|
||||
_restore_interrupt_handling(original_handler)
|
||||
|
|
|
|||
|
|
@ -2,11 +2,17 @@ import re
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
from ra_aid.fallback_handler import FallbackHandler
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||
from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT
|
||||
from ra_aid.tools.expert import get_model
|
||||
from ra_aid.tools.reflection import get_function_info
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -70,10 +76,11 @@ class CiaynAgent:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
tools: list,
|
||||
model: BaseChatModel,
|
||||
tools: list[BaseTool],
|
||||
max_history_messages: int = 50,
|
||||
max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT,
|
||||
config: Optional[dict] = None,
|
||||
):
|
||||
"""Initialize the agent with a model and list of tools.
|
||||
|
||||
|
|
@ -82,18 +89,35 @@ class CiaynAgent:
|
|||
tools: List of tools available to the agent
|
||||
max_history_messages: Maximum number of messages to keep in chat history
|
||||
max_tokens: Maximum number of tokens allowed in message history (None for no limit)
|
||||
config: Optional configuration dictionary
|
||||
"""
|
||||
if config is None:
|
||||
config = {}
|
||||
self.config = config
|
||||
self.provider = config.get("provider", "openai")
|
||||
|
||||
self.model = model
|
||||
self.tools = tools
|
||||
self.max_history_messages = max_history_messages
|
||||
self.max_tokens = max_tokens
|
||||
self.chat_history = []
|
||||
self.available_functions = []
|
||||
for t in tools:
|
||||
self.available_functions.append(get_function_info(t.func))
|
||||
|
||||
self.fallback_handler = FallbackHandler(config, tools)
|
||||
self.sys_message = SystemMessage(
|
||||
"Execute efficiently yet completely as a fully autonomous agent."
|
||||
)
|
||||
self.error_message_template = "Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
|
||||
self.fallback_fixed_msg = HumanMessage(
|
||||
"Fallback tool handler has fixed the tool call see: <fallback tool call result> for the output."
|
||||
)
|
||||
|
||||
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
||||
"""Build the prompt for the agent including available tools and context."""
|
||||
base_prompt = ""
|
||||
|
||||
if last_result is not None:
|
||||
base_prompt += f"\n<last result>{last_result}</last result>"
|
||||
|
||||
|
|
@ -101,143 +125,62 @@ class CiaynAgent:
|
|||
functions_list = "\n\n".join(self.available_functions)
|
||||
|
||||
# Build the complete prompt without f-strings for the static parts
|
||||
base_prompt += (
|
||||
"""
|
||||
base_prompt += CIAYN_AGENT_BASE_PROMPT.format(functions_list=functions_list)
|
||||
|
||||
<agent instructions>
|
||||
You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration, but you will be called in a loop, so you will be able to accomplish the task over many iterations.
|
||||
The result of that function call will be given to you in the next message.
|
||||
Call one function at a time. Function arguments can be complex objects, long strings, etc. if needed.
|
||||
The user cannot see the results of function calls, so you have to explicitly use a tool like ask_human if you want them to see something.
|
||||
You must always respond with a single line of python that calls one of the available tools.
|
||||
Use as many steps as you need to in order to fully complete the task.
|
||||
Start by asking the user what they want.
|
||||
|
||||
You must carefully review the conversation history, which functions were called so far, returned results, etc., and make sure the very next function call you make makes sense in order to achieve the original goal.
|
||||
You are expected to use as many steps as necessary to completely achieve the user's request, making many tool calls along the way.
|
||||
Think hard about what the best *next* tool call is, knowing that you can make as many calls as you need to after that.
|
||||
You typically don't want to keep calling the same function over and over with the same parameters.
|
||||
</agent instructions>
|
||||
|
||||
You must ONLY use ONE of the following functions (these are the ONLY functions that exist):
|
||||
|
||||
<available functions>"""
|
||||
+ functions_list
|
||||
+ """
|
||||
</available functions>
|
||||
|
||||
You may use any of the above functions to complete your job. Use the best one for the current step you are on. Be efficient, avoid getting stuck in repetitive loops, and do not hesitate to call functions which delegate your work to make your life easier.
|
||||
But you MUST NOT assume tools exist that are not in the above list, e.g. put_complete_file_contents.
|
||||
Consider your task done only once you have taken *ALL* the steps required to complete it.
|
||||
|
||||
--- EXAMPLE BAD OUTPUTS ---
|
||||
|
||||
This tool is not in available functions, so this is a bad tool call:
|
||||
|
||||
<example bad output>
|
||||
put_complete_file_contents(...)
|
||||
</example bad output>
|
||||
|
||||
This tool call has a syntax error (unclosed parenthesis, quotes), so it is bad:
|
||||
|
||||
<example bad output>
|
||||
put_complete_file_contents("asdf
|
||||
</example bad output>
|
||||
|
||||
This tool call is bad because it includes a message as well as backticks:
|
||||
|
||||
<example bad output>
|
||||
Sure, I'll make the following tool call to accomplish what you asked me:
|
||||
|
||||
```
|
||||
list_directory_tree('.')
|
||||
```
|
||||
</example bad output>
|
||||
|
||||
This tool call is bad because the output code is surrounded with backticks:
|
||||
|
||||
<example bad output>
|
||||
```
|
||||
list_directory_tree('.')
|
||||
```
|
||||
</example bad output>
|
||||
|
||||
The following is bad becasue it makes the same tool call multiple times in a row with the exact same parameters, for no reason, getting stuck in a loop:
|
||||
|
||||
<example bad output>
|
||||
<response 1>
|
||||
list_directory_tree('.')
|
||||
</response 1>
|
||||
<response 2>
|
||||
list_directory_tree('.')
|
||||
</response 2>
|
||||
</example bad output>
|
||||
|
||||
The following is bad because it makes more than one tool call in one response:
|
||||
|
||||
<example bad output>
|
||||
list_directory_tree('.')
|
||||
read_file_tool('README.md') # Now we've made
|
||||
</example bad output.
|
||||
|
||||
This is a good output because it calls the tool appropriately and with correct syntax:
|
||||
|
||||
--- EXAMPLE GOOD OUTPUTS ---
|
||||
|
||||
<example good output>
|
||||
request_research_and_implementation(\"\"\"
|
||||
Example query.
|
||||
\"\"\")
|
||||
</example good output>
|
||||
|
||||
This is good output because it uses a multiple line string when needed and properly calls the tool, does not output backticks or extra information:
|
||||
<example good output>
|
||||
run_programming_task(\"\"\"
|
||||
# Example Programming Task
|
||||
|
||||
Implement a widget factory satisfying the following requirements:
|
||||
|
||||
- Requirement A
|
||||
- Requirement B
|
||||
|
||||
...
|
||||
\"\"\")
|
||||
</example good output>
|
||||
|
||||
As an agent, you will carefully plan ahead, carefully analyze tool call responses, and adapt to circumstances in order to accomplish your goal.
|
||||
|
||||
You will make as many tool calls as you feel necessary in order to fully complete the task.
|
||||
|
||||
We're entrusting you with a lot of autonomy and power, so be efficient and don't mess up.
|
||||
|
||||
You have often been criticized for:
|
||||
|
||||
- Making the same function calls over and over, getting stuck in a loop.
|
||||
|
||||
DO NOT CLAIM YOU ARE FINISHED UNTIL YOU ACTUALLY ARE!
|
||||
Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||
)
|
||||
# base_prompt += "\n\nYou must reply with ONLY ONE of the functions given in available functions."
|
||||
|
||||
return base_prompt
|
||||
|
||||
def _execute_tool(self, code: str) -> str:
|
||||
def _execute_tool(self, msg: BaseMessage) -> str:
|
||||
"""Execute a tool call and return its result."""
|
||||
|
||||
code = msg.content
|
||||
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
||||
|
||||
try:
|
||||
code = code.strip()
|
||||
# code = code.replace("\n", " ")
|
||||
if code.startswith("```"):
|
||||
code = code[3:].strip()
|
||||
if code.endswith("```"):
|
||||
code = code[:-3].strip()
|
||||
|
||||
# if the eval fails, try to extract it via a model call
|
||||
if validate_function_call_pattern(code):
|
||||
functions_list = "\n\n".join(self.available_functions)
|
||||
code = _extract_tool_call(code, functions_list)
|
||||
code = self._extract_tool_call(code, functions_list)
|
||||
pass
|
||||
|
||||
result = eval(code.strip(), globals_dict)
|
||||
return result
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing code: {str(e)}"
|
||||
raise ToolExecutionError(error_msg)
|
||||
error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
|
||||
tool_name = self.extract_tool_name(code)
|
||||
raise ToolExecutionError(
|
||||
error_msg, base_message=msg, tool_name=tool_name
|
||||
) from e
|
||||
|
||||
def extract_tool_name(self, code: str) -> str:
|
||||
match = re.match(r"\s*([\w_\-]+)\s*\(", code)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return ""
|
||||
|
||||
def handle_fallback_response(
|
||||
self, fallback_response: list[Any], e: ToolExecutionError
|
||||
) -> str:
|
||||
err_msg = HumanMessage(content=self.error_message_template.format(e=e))
|
||||
|
||||
if not fallback_response:
|
||||
self.chat_history.append(err_msg)
|
||||
return ""
|
||||
|
||||
self.chat_history.append(self.fallback_fixed_msg)
|
||||
msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n"
|
||||
# Passing the fallback raw invocation may confuse our llm, as invocation methods may differ.
|
||||
# msg += f"<fallback llm raw invocation>{fallback_response[0]}</fallback llm raw invocation>\n"
|
||||
msg += f"<fallback tool name>{e.tool_name}</fallback tool name>\n"
|
||||
msg += f"<fallback tool call result>\n{fallback_response[1]}\n</fallback tool call result>\n"
|
||||
return msg
|
||||
|
||||
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
||||
"""Create an agent chunk in the format expected by print_agent_output."""
|
||||
|
|
@ -251,7 +194,6 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
@staticmethod
|
||||
def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int:
|
||||
"""Estimate number of tokens in content using simple byte length heuristic.
|
||||
|
||||
Estimates 1 token per 2.0 bytes of content. For messages, uses the content field.
|
||||
|
||||
Args:
|
||||
|
|
@ -277,6 +219,22 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
|
||||
return len(text.encode("utf-8")) // 2.0
|
||||
|
||||
def _extract_tool_call(self, code: str, functions_list: str) -> str:
|
||||
model = get_model()
|
||||
prompt = EXTRACT_TOOL_CALL_PROMPT.format(
|
||||
functions_list=functions_list, code=code
|
||||
)
|
||||
response = model.invoke(prompt)
|
||||
response = response.content
|
||||
|
||||
pattern = r"([\w_\-]+)\((.*?)\)"
|
||||
matches = re.findall(pattern, response, re.DOTALL)
|
||||
if len(matches) == 0:
|
||||
raise ToolExecutionError("Failed to extract tool call")
|
||||
ma = matches[0][0].strip()
|
||||
mb = matches[0][1].strip().replace("\n", " ")
|
||||
return f"{ma}({mb})"
|
||||
|
||||
def _trim_chat_history(
|
||||
self, initial_messages: List[Any], chat_history: List[Any]
|
||||
) -> List[Any]:
|
||||
|
|
@ -315,72 +273,28 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
return initial_messages + chat_history
|
||||
|
||||
def stream(
|
||||
self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None
|
||||
self, messages_dict: Dict[str, List[Any]], _config: Dict[str, Any] = None
|
||||
) -> Generator[Dict[str, Any], None, None]:
|
||||
"""Stream agent responses in a format compatible with print_agent_output."""
|
||||
initial_messages = messages_dict.get("messages", [])
|
||||
chat_history = []
|
||||
self.chat_history = []
|
||||
last_result = None
|
||||
first_iteration = True
|
||||
|
||||
while True:
|
||||
base_prompt = self._build_prompt(None if first_iteration else last_result)
|
||||
chat_history.append(HumanMessage(content=base_prompt))
|
||||
|
||||
full_history = self._trim_chat_history(initial_messages, chat_history)
|
||||
response = self.model.invoke(
|
||||
[
|
||||
SystemMessage(
|
||||
"Execute efficiently yet completely as a fully autonomous agent."
|
||||
)
|
||||
]
|
||||
+ full_history
|
||||
)
|
||||
base_prompt = self._build_prompt(last_result)
|
||||
self.chat_history.append(HumanMessage(content=base_prompt))
|
||||
full_history = self._trim_chat_history(initial_messages, self.chat_history)
|
||||
response = self.model.invoke([self.sys_message] + full_history)
|
||||
|
||||
try:
|
||||
logger.debug(f"Code generated by agent: {response.content}")
|
||||
last_result = self._execute_tool(response.content)
|
||||
chat_history.append(response)
|
||||
first_iteration = False
|
||||
last_result = self._execute_tool(response)
|
||||
self.chat_history.append(response)
|
||||
self.fallback_handler.reset_fallback_handler()
|
||||
yield {}
|
||||
|
||||
except ToolExecutionError as e:
|
||||
chat_history.append(
|
||||
HumanMessage(
|
||||
content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
|
||||
)
|
||||
fallback_response = self.fallback_handler.handle_failure(
|
||||
e, self, self.chat_history
|
||||
)
|
||||
yield self._create_error_chunk(str(e))
|
||||
|
||||
|
||||
def _extract_tool_call(code: str, functions_list: str) -> str:
|
||||
from ra_aid.tools.expert import get_model
|
||||
|
||||
model = get_model()
|
||||
prompt = f"""
|
||||
I'm conversing with a AI model and requiring responses in a particular format: A function call with any parameters escaped. Here is an example:
|
||||
|
||||
```
|
||||
run_programming_task("blah \" blah\" blah")
|
||||
```
|
||||
|
||||
The following tasks are allowed:
|
||||
|
||||
{functions_list}
|
||||
|
||||
I got this invalid response from the model, can you format it so it becomes a correct function call?
|
||||
|
||||
```
|
||||
{code}
|
||||
```
|
||||
"""
|
||||
response = model.invoke(prompt)
|
||||
response = response.content
|
||||
|
||||
pattern = r"([\w_\-]+)\((.*?)\)"
|
||||
matches = re.findall(pattern, response, re.DOTALL)
|
||||
if len(matches) == 0:
|
||||
raise ToolExecutionError("Failed to extract tool call")
|
||||
ma = matches[0][0].strip()
|
||||
mb = matches[0][1].strip().replace("\n", " ")
|
||||
return f"{ma}({mb})"
|
||||
last_result = self.handle_fallback_response(fallback_response, e)
|
||||
yield {}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
|
||||
# Unfortunately needed to avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
|
||||
RAgents = CompiledGraph | CiaynAgent
|
||||
else:
|
||||
RAgents = CompiledGraph
|
||||
|
|
@ -2,4 +2,17 @@
|
|||
|
||||
DEFAULT_RECURSION_LIMIT = 100
|
||||
DEFAULT_MAX_TEST_CMD_RETRIES = 3
|
||||
DEFAULT_MAX_TOOL_FAILURES = 3
|
||||
FALLBACK_TOOL_MODEL_LIMIT = 5
|
||||
RETRY_FALLBACK_COUNT = 3
|
||||
DEFAULT_TEST_CMD_TIMEOUT = 60 * 5 # 5 minutes in seconds
|
||||
|
||||
|
||||
VALID_PROVIDERS = [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"openrouter",
|
||||
"openai-compatible",
|
||||
"deepseek",
|
||||
"gemini",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,18 +1,23 @@
|
|||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
|
||||
# Import shared console instance
|
||||
from .formatting import console
|
||||
|
||||
|
||||
def print_agent_output(chunk: Dict[str, Any]) -> None:
|
||||
def print_agent_output(
|
||||
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
|
||||
) -> None:
|
||||
"""Print only the agent's message content, not tool calls.
|
||||
|
||||
Args:
|
||||
chunk: A dictionary containing agent or tool messages
|
||||
chunk: A dictionary containing agent or tool messages.
|
||||
agent_type: Specifies the type of agent. 'CiaynAgent' handles tool errors internally, while 'React' raises a ToolExecutionError.
|
||||
"""
|
||||
if "agent" in chunk and "messages" in chunk["agent"]:
|
||||
messages = chunk["agent"]["messages"]
|
||||
|
|
@ -33,10 +38,31 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
|
|||
elif "tools" in chunk and "messages" in chunk["tools"]:
|
||||
for msg in chunk["tools"]["messages"]:
|
||||
if msg.status == "error" and msg.content:
|
||||
err_msg = msg.content.strip()
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(msg.content.strip()),
|
||||
Markdown(err_msg),
|
||||
title="❌ Tool Error",
|
||||
border_style="red bold",
|
||||
)
|
||||
)
|
||||
tool_name = getattr(msg, "name", None)
|
||||
|
||||
# CiaynAgent handles this internally
|
||||
if agent_type == "React":
|
||||
raise ToolExecutionError(
|
||||
err_msg, tool_name=tool_name, base_message=msg
|
||||
)
|
||||
|
||||
|
||||
def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:
|
||||
"""
|
||||
Print a message using a Panel with Markdown formatting.
|
||||
|
||||
Args:
|
||||
message (str): The message content to display.
|
||||
title (Optional[str]): An optional title for the panel.
|
||||
border_style (str): Border style for the panel.
|
||||
"""
|
||||
|
||||
console.print(Panel(Markdown(message), title=title, border_style=border_style))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
"""Custom exceptions for RA.Aid."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
|
||||
class AgentInterrupt(Exception):
|
||||
"""Exception raised when an agent's execution is interrupted.
|
||||
|
|
@ -18,4 +22,18 @@ class ToolExecutionError(Exception):
|
|||
from other types of errors in the agent system.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
base_message: Optional[BaseMessage] = None,
|
||||
tool_name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.base_message = base_message
|
||||
self.tool_name = tool_name
|
||||
|
||||
|
||||
class FallbackToolExecutionError(Exception):
|
||||
"""Exception raised when a fallback tool execution fails."""
|
||||
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -0,0 +1,437 @@
|
|||
import json
|
||||
import re
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.graph.message import BaseMessage
|
||||
|
||||
from ra_aid.agents_alias import RAgents
|
||||
from ra_aid.config import (
|
||||
DEFAULT_MAX_TOOL_FAILURES,
|
||||
FALLBACK_TOOL_MODEL_LIMIT,
|
||||
RETRY_FALLBACK_COUNT,
|
||||
)
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
|
||||
from ra_aid.llm import initialize_llm, validate_provider_env
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.tool_configs import get_all_tools
|
||||
from ra_aid.tool_leaderboard import supported_top_tool_models
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FallbackHandler:
|
||||
"""
|
||||
FallbackHandler manages fallback logic when tool execution fails.
|
||||
|
||||
It loads fallback models from configuration and validated provider settings,
|
||||
maintains failure counts, and triggers appropriate fallback methods for both
|
||||
prompt-based and function-calling tool invocations. It also resets internal
|
||||
counters when a tool call succeeds.
|
||||
"""
|
||||
|
||||
def __init__(self, config, tools):
|
||||
"""
|
||||
Initialize the FallbackHandler with the given configuration and tools.
|
||||
|
||||
Args:
|
||||
config (dict): Configuration dictionary that may include fallback settings.
|
||||
tools (list): List of available tools.
|
||||
"""
|
||||
self.config = config
|
||||
self.tools: list[BaseTool] = tools
|
||||
self.fallback_enabled = config.get("experimental_fallback_handler", False)
|
||||
self.fallback_tool_models = self._load_fallback_tool_models(config)
|
||||
self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
|
||||
self.tool_failure_consecutive_failures = 0
|
||||
self.failed_messages: list[BaseMessage] = []
|
||||
self.current_failing_tool_name = ""
|
||||
self.current_tool_to_bind: None | BaseTool = None
|
||||
self.msg_list: list[BaseMessage] = []
|
||||
|
||||
cpm(
|
||||
"Fallback models selected: "
|
||||
+ ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
|
||||
title="Fallback Models",
|
||||
)
|
||||
|
||||
def _format_model(self, m: dict) -> str:
|
||||
return f"{m.get('model', '')} ({m.get('type', 'prompt')})"
|
||||
|
||||
def _load_fallback_tool_models(self, _config):
|
||||
"""
|
||||
Load and return fallback tool models based on the provided configuration.
|
||||
|
||||
If the config specifies 'fallback_tool_models', those are used (assuming comma-separated names).
|
||||
Otherwise, this method filters the supported_top_tool_models based on provider environment validation,
|
||||
selecting up to FALLBACK_TOOL_MODEL_LIMIT models.
|
||||
|
||||
Args:
|
||||
config (dict): Configuration dictionary.
|
||||
|
||||
Returns:
|
||||
list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model.
|
||||
"""
|
||||
supported = []
|
||||
skipped = []
|
||||
for item in supported_top_tool_models:
|
||||
provider = item.get("provider")
|
||||
model_name = item.get("model")
|
||||
if validate_provider_env(provider):
|
||||
supported.append(item)
|
||||
if len(supported) == FALLBACK_TOOL_MODEL_LIMIT:
|
||||
break
|
||||
else:
|
||||
skipped.append(model_name)
|
||||
final_models = []
|
||||
for item in supported:
|
||||
if "type" not in item:
|
||||
item["type"] = "prompt"
|
||||
item["model"] = item["model"].lower()
|
||||
final_models.append(item)
|
||||
message = "Fallback models selected: " + ", ".join(
|
||||
[m["model"] for m in final_models]
|
||||
)
|
||||
if skipped:
|
||||
message += (
|
||||
"\nSkipped top tool calling models due to missing provider ENV API keys: "
|
||||
+ ", ".join(skipped)
|
||||
)
|
||||
return final_models
|
||||
|
||||
def handle_failure(
|
||||
self, error: ToolExecutionError, agent: RAgents, msg_list: list[BaseMessage]
|
||||
):
|
||||
"""
|
||||
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
||||
|
||||
Args:
|
||||
error (Exception): The exception raised during execution. If the exception has a 'base_message' attribute, that message is recorded.
|
||||
agent: The agent instance on which fallback may be executed.
|
||||
"""
|
||||
if not self.fallback_enabled:
|
||||
return None
|
||||
|
||||
if self.tool_failure_consecutive_failures == 0:
|
||||
self.init_msg_list(msg_list)
|
||||
|
||||
failed_tool_call_name = self.extract_failed_tool_name(error)
|
||||
self._reset_on_new_failure(failed_tool_call_name)
|
||||
|
||||
logger.debug(
|
||||
f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}"
|
||||
)
|
||||
|
||||
self.current_failing_tool_name = failed_tool_call_name
|
||||
self.current_tool_to_bind = self._find_tool_to_bind(
|
||||
agent, failed_tool_call_name
|
||||
)
|
||||
|
||||
if hasattr(error, "base_message") and error.base_message:
|
||||
self.failed_messages.append(error.base_message)
|
||||
|
||||
self.tool_failure_consecutive_failures += 1
|
||||
logger.debug(
|
||||
f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {self.max_failures}"
|
||||
)
|
||||
|
||||
if (
|
||||
self.tool_failure_consecutive_failures >= self.max_failures
|
||||
and self.fallback_tool_models
|
||||
):
|
||||
logger.debug(
|
||||
"_handle_tool_failure: threshold reached, invoking fallback mechanism."
|
||||
)
|
||||
return self.attempt_fallback()
|
||||
|
||||
def attempt_fallback(self):
|
||||
"""
|
||||
Initiate the fallback process by iterating over all fallback models to attempt to fix the failing tool call.
|
||||
|
||||
Returns:
|
||||
List of [raw_llm_response (SystemMessage), tool_call_result (SystemMessage)] or None.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
|
||||
)
|
||||
cpm(
|
||||
f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
|
||||
title="Fallback Notification",
|
||||
)
|
||||
for fallback_model in self.fallback_tool_models:
|
||||
result_list = self.invoke_fallback(fallback_model)
|
||||
if result_list:
|
||||
return result_list
|
||||
|
||||
cpm("All fallback models have failed.", title="Fallback Failed")
|
||||
|
||||
current_failing_tool_name = self.current_failing_tool_name
|
||||
self.reset_fallback_handler()
|
||||
|
||||
raise FallbackToolExecutionError(
|
||||
f"All fallback models have failed for tool: {current_failing_tool_name}"
|
||||
)
|
||||
|
||||
def reset_fallback_handler(self):
|
||||
"""
|
||||
Reset the fallback handler's internal failure counters and clear the record of used fallback models.
|
||||
"""
|
||||
self.tool_failure_consecutive_failures = 0
|
||||
self.failed_messages.clear()
|
||||
self.fallback_tool_models = self._load_fallback_tool_models(self.config)
|
||||
self.current_failing_tool_name = ""
|
||||
self.current_tool_to_bind = None
|
||||
self.msg_list = []
|
||||
|
||||
def _reset_on_new_failure(self, failed_tool_call_name):
|
||||
if (
|
||||
self.current_failing_tool_name
|
||||
and failed_tool_call_name != self.current_failing_tool_name
|
||||
):
|
||||
logger.debug(
|
||||
"New failing tool name identified. Resetting consecutive tool failures."
|
||||
)
|
||||
self.reset_fallback_handler()
|
||||
|
||||
def extract_failed_tool_name(self, error: ToolExecutionError):
|
||||
if error.tool_name:
|
||||
failed_tool_call_name = error.tool_name
|
||||
else:
|
||||
msg = str(error)
|
||||
logger.debug("Error message: %s", msg)
|
||||
match = re.search(r"name=['\"](\w+)['\"]", msg)
|
||||
if match:
|
||||
failed_tool_call_name = str(match.group(1))
|
||||
logger.debug(
|
||||
"Extracted failed_tool_call_name using regex: %s",
|
||||
failed_tool_call_name,
|
||||
)
|
||||
else:
|
||||
failed_tool_call_name = "Tool execution error"
|
||||
raise FallbackToolExecutionError(
|
||||
"Fallback failed: Could not extract failed tool name."
|
||||
)
|
||||
|
||||
return failed_tool_call_name
|
||||
|
||||
def _find_tool_to_bind(self, agent, failed_tool_call_name):
|
||||
logger.debug(f"failed_tool_call_name={failed_tool_call_name}")
|
||||
tool_to_bind = None
|
||||
if hasattr(agent, "tools"):
|
||||
tool_to_bind = next(
|
||||
(t for t in agent.tools if t.func.__name__ == failed_tool_call_name),
|
||||
None,
|
||||
)
|
||||
if tool_to_bind is None:
|
||||
all_tools = get_all_tools()
|
||||
tool_to_bind = next(
|
||||
(t for t in all_tools if t.func.__name__ == failed_tool_call_name),
|
||||
None,
|
||||
)
|
||||
if tool_to_bind is None:
|
||||
# TODO: Would be nice to try fuzzy matching or Levenshtein string distance to find the closest corresponding tool name.
|
||||
raise FallbackToolExecutionError(
|
||||
f"Fallback failed failed_tool_call_name: '{failed_tool_call_name}' not found in any available tools."
|
||||
)
|
||||
return tool_to_bind
|
||||
|
||||
def _bind_tool_model(self, simple_model: BaseChatModel, fallback_model):
|
||||
if fallback_model.get("type", "prompt").lower() == "fc":
|
||||
# Force tool calling with tool_choice param.
|
||||
bound_model = simple_model.bind_tools(
|
||||
[self.current_tool_to_bind],
|
||||
tool_choice=self.current_failing_tool_name,
|
||||
)
|
||||
else:
|
||||
# Do not force tool calling (Prompt method)
|
||||
bound_model = simple_model.bind_tools([self.current_tool_to_bind])
|
||||
return bound_model
|
||||
|
||||
def invoke_fallback(self, fallback_model):
|
||||
"""
|
||||
Attempt a Prompt or function-calling fallback by invoking the current failing tool with the given fallback model.
|
||||
|
||||
Args:
|
||||
fallback_model (dict): The fallback model to use.
|
||||
|
||||
Returns:
|
||||
The response from the fallback model invocation, or None if failed.
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Trying fallback model: {self._format_model(fallback_model)}")
|
||||
simple_model = initialize_llm(
|
||||
fallback_model["provider"], fallback_model["model"]
|
||||
)
|
||||
|
||||
bound_model = self._bind_tool_model(simple_model, fallback_model)
|
||||
|
||||
retry_model = bound_model.with_retry(
|
||||
stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||
)
|
||||
|
||||
msg_list = self.construct_prompt_msg_list()
|
||||
response = retry_model.invoke(msg_list)
|
||||
|
||||
tool_call = self.base_message_to_tool_call_dict(response)
|
||||
|
||||
tool_call_result = self.invoke_prompt_tool_call(tool_call)
|
||||
# cpm(str(tool_call_result), title="Fallback Tool Call Result")
|
||||
logger.debug(
|
||||
f"Fallback call successful with model: {self._format_model(fallback_model)}"
|
||||
)
|
||||
|
||||
self.reset_fallback_handler()
|
||||
return [response, tool_call_result]
|
||||
except Exception as e:
|
||||
if isinstance(e, KeyboardInterrupt):
|
||||
raise
|
||||
logger.error(
|
||||
f"Fallback with model {self._format_model(fallback_model)} failed: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def construct_prompt_msg_list(self):
|
||||
"""
|
||||
Construct a list of chat messages for the fallback prompt.
|
||||
The initial message instructs the assistant that it is a fallback tool caller.
|
||||
Then includes the failed tool call messages from self.failed_messages.
|
||||
Finally, it appends a human message asking it to retry calling the tool with correct valid arguments.
|
||||
|
||||
Returns:
|
||||
list: A list of chat messages.
|
||||
"""
|
||||
prompt_msg_list: list[BaseMessage] = []
|
||||
prompt_msg_list.append(
|
||||
SystemMessage(
|
||||
content="You are a fallback tool caller. Your only responsibility is to figure out what the previous failed tool call was trying to do and to call that tool with the correct format and arguments, using the provided failure messages."
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: Have some way to use the correct message type in the future; don't just convert everything to a system message.
|
||||
# This may be difficult as each model type may require different chat structures and throw API errors.
|
||||
prompt_msg_list.extend(SystemMessage(str(msg)) for msg in self.msg_list)
|
||||
|
||||
if self.failed_messages:
|
||||
# Convert to system messages to avoid API errors asking for correct msg structure
|
||||
prompt_msg_list.extend(
|
||||
[SystemMessage(str(msg)) for msg in self.failed_messages]
|
||||
)
|
||||
|
||||
prompt_msg_list.append(
|
||||
HumanMessage(
|
||||
content=f"Retry using the tool: '{self.current_failing_tool_name}' with correct arguments and formatting."
|
||||
)
|
||||
)
|
||||
return prompt_msg_list
|
||||
|
||||
def invoke_prompt_tool_call(self, tool_call_request: dict):
|
||||
"""
|
||||
Invoke a tool call from a prompt-based fallback response.
|
||||
|
||||
Args:
|
||||
tool_call_request (dict): The tool call request containing keys 'type', 'name', and 'arguments'.
|
||||
|
||||
Returns:
|
||||
The result of invoking the tool.
|
||||
"""
|
||||
tool_name_to_tool = {
|
||||
getattr(tool.func, "__name__", None): tool for tool in self.tools
|
||||
}
|
||||
name = tool_call_request["name"]
|
||||
arguments = tool_call_request["arguments"]
|
||||
if name in tool_name_to_tool:
|
||||
return tool_name_to_tool[name].invoke(arguments)
|
||||
elif (
|
||||
self.current_tool_to_bind is not None
|
||||
and getattr(self.current_tool_to_bind.func, "__name__", None) == name
|
||||
):
|
||||
return self.current_tool_to_bind.invoke(arguments)
|
||||
else:
|
||||
raise FallbackToolExecutionError(
|
||||
f"Tool '{name}' not found in available tools."
|
||||
)
|
||||
|
||||
def base_message_to_tool_call_dict(self, response: BaseMessage):
|
||||
"""
|
||||
Extracts a tool call dictionary from a BaseMessage.
|
||||
|
||||
Args:
|
||||
response: The response object containing tool call data.
|
||||
|
||||
Returns:
|
||||
A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found,
|
||||
otherwise None.
|
||||
"""
|
||||
tool_calls = self.get_tool_calls(response)
|
||||
|
||||
if not tool_calls:
|
||||
raise FallbackToolExecutionError(
|
||||
f"Could not extract tool_call_dict from response: {response}"
|
||||
)
|
||||
|
||||
if len(tool_calls) > 1:
|
||||
logger.warning("Multiple tool calls detected, using the first one")
|
||||
|
||||
tool_call = tool_calls[0]
|
||||
return {
|
||||
"id": tool_call["id"],
|
||||
"type": tool_call["type"],
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": self._parse_tool_arguments(tool_call["function"]["arguments"]),
|
||||
}
|
||||
|
||||
def _parse_tool_arguments(self, tool_arguments):
|
||||
"""
|
||||
Helper method to parse tool call arguments.
|
||||
If tool_arguments is a string, it returns the JSON-parsed dictionary.
|
||||
Otherwise, returns tool_arguments as is.
|
||||
"""
|
||||
if isinstance(tool_arguments, str):
|
||||
return json.loads(tool_arguments)
|
||||
return tool_arguments
|
||||
|
||||
def get_tool_calls(self, response: BaseMessage):
|
||||
"""
|
||||
Extracts tool calls list from a fallback response.
|
||||
|
||||
Args:
|
||||
response: The response object containing tool call data.
|
||||
|
||||
Returns:
|
||||
The tool calls list if present, otherwise None.
|
||||
"""
|
||||
tool_calls = None
|
||||
if hasattr(response, "additional_kwargs") and response.additional_kwargs.get(
|
||||
"tool_calls"
|
||||
):
|
||||
tool_calls = response.additional_kwargs.get("tool_calls")
|
||||
elif hasattr(response, "tool_calls"):
|
||||
tool_calls = response.tool_calls
|
||||
elif isinstance(response, dict) and response.get("additional_kwargs", {}).get(
|
||||
"tool_calls"
|
||||
):
|
||||
tool_calls = response.get("additional_kwargs").get("tool_calls")
|
||||
return tool_calls
|
||||
|
||||
def handle_failure_response(
|
||||
self, error: ToolExecutionError, agent, agent_type: str
|
||||
):
|
||||
"""
|
||||
Handle a tool failure by calling handle_failure and, if a fallback response is returned and the agent type is "React",
|
||||
return a list of SystemMessage objects wrapping each message from the fallback response.
|
||||
"""
|
||||
fallback_response = self.handle_failure(error, agent)
|
||||
if fallback_response and agent_type == "React":
|
||||
return [SystemMessage(str(msg)) for msg in fallback_response]
|
||||
return None
|
||||
|
||||
def init_msg_list(self, full_msg_list: list[BaseMessage]) -> None:
|
||||
first_two = full_msg_list[:2]
|
||||
last_two = full_msg_list[-2:]
|
||||
merged = first_two.copy()
|
||||
for msg in last_two:
|
||||
if msg not in merged:
|
||||
merged.append(msg)
|
||||
self.msg_list = merged
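For orientation, here is a minimal sketch of how the new fallback pieces fit together, mirroring what `main()` and the `run_*_agent` helpers now do. The provider, model name, and prompt below are placeholder values, and the real code also populates global memory and validates the environment before reaching this point:

```python
from ra_aid.agent_utils import create_agent, init_fallback_handler, run_agent_with_retry
from ra_aid.llm import initialize_llm
from ra_aid.tool_configs import get_all_tools

# Placeholder config; in RA.Aid this dict is built from parse_arguments() in main().
config = {
    "provider": "openai",
    "model": "gpt-4o",
    "experimental_fallback_handler": True,  # the new CLI flag, threaded through the config dict
}

model = initialize_llm(config["provider"], config["model"])
tools = get_all_tools()
agent = create_agent(model, tools, agent_type="default")

# Returns a FallbackHandler only for React-type agents when the flag is enabled;
# CiaynAgent constructs its own handler internally, so None is returned for it.
none_or_fallback_handler = init_fallback_handler(agent, config, tools)

# ToolExecutionError failures are routed to the handler; after DEFAULT_MAX_TOOL_FAILURES
# consecutive failures of the same tool, it retries the call against the fallback models.
result = run_agent_with_retry(
    agent, "Describe the repository layout.", config, none_or_fallback_handler
)
```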
@ -8,6 +8,7 @@ from langchain_openai import ChatOpenAI
|
|||
from openai import OpenAI
|
||||
|
||||
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
from .models_params import models_params
|
||||
|
|
@ -96,9 +97,9 @@ def create_deepseek_client(
|
|||
return ChatDeepseekReasoner(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
temperature=0
|
||||
if is_expert
|
||||
else (temperature if temperature is not None else 1),
|
||||
temperature=(
|
||||
0 if is_expert else (temperature if temperature is not None else 1)
|
||||
),
|
||||
model=model_name,
|
||||
timeout=LLM_REQUEST_TIMEOUT,
|
||||
max_retries=LLM_MAX_RETRIES,
|
||||
|
|
@ -125,9 +126,9 @@ def create_openrouter_client(
|
|||
return ChatDeepseekReasoner(
|
||||
api_key=api_key,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
temperature=0
|
||||
if is_expert
|
||||
else (temperature if temperature is not None else 1),
|
||||
temperature=(
|
||||
0 if is_expert else (temperature if temperature is not None else 1)
|
||||
),
|
||||
model=model_name,
|
||||
timeout=LLM_REQUEST_TIMEOUT,
|
||||
max_retries=LLM_MAX_RETRIES,
|
||||
|
|
@ -171,7 +172,12 @@ def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any
|
|||
"base_url": "https://api.deepseek.com",
|
||||
},
|
||||
}
|
||||
return configs.get(provider, {})
|
||||
config = configs.get(provider, {})
|
||||
if not config or not config.get("api_key"):
|
||||
raise ValueError(
|
||||
f"Missing required environment variable for provider: {provider}"
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
|
|
@ -222,8 +228,9 @@ def create_llm_client(
|
|||
temp_kwargs = {"temperature": 0} if supports_temperature else {}
|
||||
elif supports_temperature:
|
||||
if temperature is None:
|
||||
raise ValueError(
|
||||
f"Temperature must be provided for model {model_name} which supports temperature"
|
||||
temperature = 0.7
|
||||
cpm(
|
||||
"This model supports temperature argument but none was given. Setting default temperature to 0.7."
|
||||
)
|
||||
temp_kwargs = {"temperature": temperature}
|
||||
else:
|
||||
|
|
@@ -298,3 +305,19 @@ def initialize_llm(
def initialize_expert_llm(provider: str, model_name: str) -> BaseChatModel:
    """Initialize an expert language model client based on the specified provider and model."""
    return create_llm_client(provider, model_name, temperature=None, is_expert=True)


def validate_provider_env(provider: str) -> bool:
    """Check if the required environment variables for a provider are set."""
    required_vars = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "openrouter": "OPENROUTER_API_KEY",
        "openai-compatible": "OPENAI_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
    }
    key = required_vars.get(provider.lower())
    if key:
        return bool(os.getenv(key))
    return False
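A hedged usage sketch of the new helper; the surrounding fallback logic is assumed, only validate_provider_env and initialize_llm come from this module:

```python
from ra_aid.llm import initialize_llm, validate_provider_env

# Only attempt a fallback model whose provider has credentials configured.
if validate_provider_env("openai"):  # checks OPENAI_API_KEY via os.getenv
    model = initialize_llm("openai", "gpt-4o")
```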
@@ -2,17 +2,54 @@ import logging
import sys
from typing import Optional

from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel

def setup_logging(verbose: bool = False) -> None:

class PrettyHandler(logging.Handler):
    def __init__(self, level=logging.NOTSET):
        super().__init__(level)
        self.console = Console()

    def emit(self, record):
        try:
            msg = self.format(record)
            # Determine title and style based on log level
            if record.levelno >= logging.CRITICAL:
                title = "🔥 CRITICAL"
                style = "bold red"
            elif record.levelno >= logging.ERROR:
                title = "❌ ERROR"
                style = "red"
            elif record.levelno >= logging.WARNING:
                title = "⚠️ WARNING"
                style = "yellow"
            elif record.levelno >= logging.INFO:
                title = "ℹ️ INFO"
                style = "green"
            else:
                title = "🐞 DEBUG"
                style = "blue"
            self.console.print(Panel(Markdown(msg.strip()), title=title, style=style))
        except Exception:
            self.handleError(record)


def setup_logging(verbose: bool = False, pretty: bool = False) -> None:
    logger = logging.getLogger("ra_aid")
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        handler.setFormatter(formatter)
        if pretty:
            handler = PrettyHandler()
        else:
            print("USING STREAM HANDLER LOGGER")
            handler = logging.StreamHandler(sys.stdout)
            formatter = logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
            handler.setFormatter(formatter)
        logger.addHandler(handler)
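A usage sketch for the new pretty logger path (the `--pretty-logger` flag wiring appears earlier in this commit); the message text is illustrative:

```python
import logging

from ra_aid.logging_config import setup_logging

setup_logging(verbose=True, pretty=True)  # installs PrettyHandler on the "ra_aid" logger
logging.getLogger("ra_aid").info("**Fallback handler** enabled")  # rendered as a Markdown panel
```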
@@ -984,3 +984,130 @@ Remember, if you do not make any tool call (e.g. ask_human to tell them a messag

NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
"""

EXTRACT_TOOL_CALL_PROMPT = """I'm conversing with an AI model and requiring responses in a particular format: A function call with any parameters escaped. Here is an example:
```
run_programming_task("blah \" blah\" blah")
```

The following tasks are allowed:

{functions_list}

I got this invalid response from the model, can you format it so it becomes a correct function call?

```
{code}
```"""

CIAYN_AGENT_BASE_PROMPT = """<agent instructions>
You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration, but you will be called in a loop, so you will be able to accomplish the task over many iterations.
The result of that function call will be given to you in the next message.
Call one function at a time. Function arguments can be complex objects, long strings, etc. if needed.
The user cannot see the results of function calls, so you have to explicitly use a tool like ask_human if you want them to see something.
You must always respond with a single line of python that calls one of the available tools.
Use as many steps as you need to in order to fully complete the task.
Start by asking the user what they want.

You must carefully review the conversation history, which functions were called so far, returned results, etc., and make sure the very next function call you make makes sense in order to achieve the original goal.
You are expected to use as many steps as necessary to completely achieve the user's request, making many tool calls along the way.
Think hard about what the best *next* tool call is, knowing that you can make as many calls as you need to after that.
You typically don't want to keep calling the same function over and over with the same parameters.
</agent instructions>

You must ONLY use ONE of the following functions (these are the ONLY functions that exist):

<available functions>{functions_list}
</available functions>

You may use any of the above functions to complete your job. Use the best one for the current step you are on. Be efficient, avoid getting stuck in repetitive loops, and do not hesitate to call functions which delegate your work to make your life easier.
But you MUST NOT assume tools exist that are not in the above list, e.g. write_file_tool.
Consider your task done only once you have taken *ALL* the steps required to complete it.

--- EXAMPLE BAD OUTPUTS ---

This tool is not in available functions, so this is a bad tool call:

<example bad output>
write_file_tool(...)
</example bad output>

This tool call has a syntax error (unclosed parenthesis, quotes), so it is bad:

<example bad output>
write_file_tool("asdf
</example bad output>

This tool call is bad because it includes a message as well as backticks:

<example bad output>
Sure, I'll make the following tool call to accomplish what you asked me:

```
list_directory_tree('.')
```
</example bad output>

This tool call is bad because the output code is surrounded with backticks:

<example bad output>
```
list_directory_tree('.')
```
</example bad output>

The following is bad because it makes the same tool call multiple times in a row with the exact same parameters, for no reason, getting stuck in a loop:

<example bad output>
<response 1>
list_directory_tree('.')
</response 1>
<response 2>
list_directory_tree('.')
</response 2>
</example bad output>

The following is bad because it makes more than one tool call in one response:

<example bad output>
list_directory_tree('.')
read_file_tool('README.md') # Now we've made
</example bad output>

This is a good output because it calls the tool appropriately and with correct syntax:

--- EXAMPLE GOOD OUTPUTS ---

<example good output>
request_research_and_implementation(\"\"\"
Example query.
\"\"\")
</example good output>

This is good output because it uses a multiple line string when needed and properly calls the tool, does not output backticks or extra information:
<example good output>
run_programming_task(\"\"\"
# Example Programming Task

Implement a widget factory satisfying the following requirements:

- Requirement A
- Requirement B

...
\"\"\")
</example good output>

As an agent, you will carefully plan ahead, carefully analyze tool call responses, and adapt to circumstances in order to accomplish your goal.

You will make as many tool calls as you feel necessary in order to fully complete the task.

We're entrusting you with a lot of autonomy and power, so be efficient and don't mess up.

You have often been criticized for:

- Making the same function calls over and over, getting stuck in a loop.

DO NOT CLAIM YOU ARE FINISHED UNTIL YOU ACTUALLY ARE!
Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**
"""
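A small sketch of how the extraction prompt might be filled in before asking a model to repair a malformed call; the argument values below are made up for illustration:

```python
from ra_aid.prompts import EXTRACT_TOOL_CALL_PROMPT

# Hypothetical values: a list of allowed tools and one malformed response.
prompt = EXTRACT_TOOL_CALL_PROMPT.format(
    functions_list="run_programming_task\nread_file_tool",
    code='run_programming_task("fix the failing test',  # unclosed quote/paren
)
```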
@@ -1,3 +1,5 @@
from langchain_core.tools import BaseTool

from ra_aid.tools import (
    ask_expert,
    ask_human,

@@ -8,7 +10,6 @@ from ra_aid.tools import (
    emit_research_notes,
    fuzzy_find_project_files,
    list_directory_tree,
    plan_implementation_completed,
    read_file_tool,
    ripgrep_search,
    run_programming_task,

@@ -23,13 +24,13 @@ from ra_aid.tools.agent import (
    request_task_implementation,
    request_web_research,
)
from ra_aid.tools.memory import one_shot_completed
from ra_aid.tools.memory import one_shot_completed, plan_implementation_completed


# Read-only tools that don't modify system state
def get_read_only_tools(
    human_interaction: bool = False, web_research_enabled: bool = False
) -> list:
):
    """Get the list of read-only tools, optionally including human interaction tools.

    Args:

@@ -63,6 +64,18 @@ def get_read_only_tools(
    return tools


def get_all_tools() -> list[BaseTool]:
    """Return a list containing all available tools from different groups."""
    all_tools = []
    all_tools.extend(get_read_only_tools())
    all_tools.extend(MODIFICATION_TOOLS)
    all_tools.extend(EXPERT_TOOLS)
    all_tools.extend(RESEARCH_TOOLS)
    all_tools.extend(get_web_research_tools())
    all_tools.extend(get_chat_tools())
    return all_tools


# Define constant tool groups
READ_ONLY_TOOLS = get_read_only_tools()
# MODIFICATION_TOOLS = [run_programming_task, put_complete_file_contents]

@@ -85,7 +98,7 @@ def get_research_tools(
    expert_enabled: bool = True,
    human_interaction: bool = False,
    web_research_enabled: bool = False,
) -> list:
):
    """Get the list of research tools based on mode and whether expert is enabled.

    Args:

@@ -165,7 +178,7 @@ def get_implementation_tools(
    return tools


def get_web_research_tools(expert_enabled: bool = True) -> list:
def get_web_research_tools(expert_enabled: bool = True):
    """Get the list of tools available for web research.

    Args:

@@ -185,9 +198,7 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
    return tools


def get_chat_tools(
    expert_enabled: bool = True, web_research_enabled: bool = False
) -> list:
def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False):
    """Get the list of tools available in chat mode.

    Chat mode includes research and implementation capabilities but excludes
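One plausible use of the new aggregate helper (a sketch only, not taken from this diff): collecting every registered tool so a failing tool can be looked up by name.

```python
from ra_aid.tool_configs import get_all_tools

all_tools = get_all_tools()
# Assumes each entry is a wrapper exposing the underlying callable via .func,
# as the tests elsewhere in this commit do.
tool_names = [t.func.__name__ for t in all_tools]
```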
@@ -0,0 +1,529 @@
from ra_aid.config import VALID_PROVIDERS

# Data extracted at 2/10/2025:
# https://gorilla.cs.berkeley.edu/leaderboard.html
# In order of overall_acc
leaderboard_data = [
{
|
||||
"overall_acc": 74.31,
|
||||
"model": "watt-tool-70B",
|
||||
"type": "FC",
|
||||
"link": "https://huggingface.co/watt-ai/watt-tool-70B/",
|
||||
"cost": "N/A",
|
||||
"latency": 3.4,
|
||||
"ast_summary": 84.06,
|
||||
"exec_summary": 89.39,
|
||||
"live_ast_acc": 77.74,
|
||||
"multi_turn_acc": 58.75,
|
||||
"relevance": 94.44,
|
||||
"irrelevance": 76.32,
|
||||
"organization": "Watt AI Lab",
|
||||
"license": "Apache-2.0",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 72.08,
|
||||
"model": "gpt-4o-2024-11-20",
|
||||
"type": "Prompt",
|
||||
"link": "https://openai.com/index/hello-gpt-4o/",
|
||||
"cost": 13.54,
|
||||
"latency": 0.78,
|
||||
"ast_summary": 88.1,
|
||||
"exec_summary": 89.38,
|
||||
"live_ast_acc": 79.83,
|
||||
"multi_turn_acc": 47.62,
|
||||
"relevance": 83.33,
|
||||
"irrelevance": 83.76,
|
||||
"organization": "OpenAI",
|
||||
"license": "Proprietary",
|
||||
"provider": "openai",
|
||||
},
|
||||
{
|
||||
"overall_acc": 69.58,
|
||||
"model": "gpt-4o-2024-11-20",
|
||||
"type": "FC",
|
||||
"link": "https://openai.com/index/hello-gpt-4o/",
|
||||
"cost": 8.23,
|
||||
"latency": 1.11,
|
||||
"ast_summary": 87.42,
|
||||
"exec_summary": 89.2,
|
||||
"live_ast_acc": 79.65,
|
||||
"multi_turn_acc": 41,
|
||||
"relevance": 83.33,
|
||||
"irrelevance": 83.15,
|
||||
"organization": "OpenAI",
|
||||
"license": "Proprietary",
|
||||
"provider": "openai",
|
||||
},
|
||||
{
|
||||
"overall_acc": 67.98,
|
||||
"model": "watt-tool-8B",
|
||||
"type": "FC",
|
||||
"link": "https://huggingface.co/watt-ai/watt-tool-8B/",
|
||||
"cost": "N/A",
|
||||
"latency": 1.31,
|
||||
"ast_summary": 86.56,
|
||||
"exec_summary": 89.34,
|
||||
"live_ast_acc": 76.5,
|
||||
"multi_turn_acc": 39.12,
|
||||
"relevance": 83.33,
|
||||
"irrelevance": 83.15,
|
||||
"organization": "Watt AI Lab",
|
||||
"license": "Apache-2.0",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 67.88,
|
||||
"model": "GPT-4-turbo-2024-04-09",
|
||||
"type": "FC",
|
||||
"link": "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
|
||||
"cost": 33.22,
|
||||
"latency": 2.47,
|
||||
"ast_summary": 84.73,
|
||||
"exec_summary": 85.21,
|
||||
"live_ast_acc": 80.5,
|
||||
"multi_turn_acc": 38.12,
|
||||
"relevance": 72.22,
|
||||
"irrelevance": 83.81,
|
||||
"organization": "OpenAI",
|
||||
"license": "Proprietary",
|
||||
"provider": "openai",
|
||||
},
|
||||
{
|
||||
"overall_acc": 66.73,
|
||||
"model": "o1-2024-12-17",
|
||||
"type": "Prompt",
|
||||
"link": "https://openai.com/o1/",
|
||||
"cost": 102.47,
|
||||
"latency": 5.3,
|
||||
"ast_summary": 85.67,
|
||||
"exec_summary": 79.77,
|
||||
"live_ast_acc": 80.63,
|
||||
"multi_turn_acc": 36,
|
||||
"relevance": 72.22,
|
||||
"irrelevance": 87.78,
|
||||
"organization": "OpenAI",
|
||||
"license": "Proprietary",
|
||||
"provider": "openai",
|
||||
},
|
||||
{
|
||||
"overall_acc": 64.1,
|
||||
"model": "GPT-4o-mini-2024-07-18",
|
||||
"type": "FC",
|
||||
"link": "https://openai.com/index/gpt-4o-mini-advancing-cost-efficient-intelligence/",
|
||||
"cost": 0.51,
|
||||
"latency": 1.49,
|
||||
"ast_summary": 85.21,
|
||||
"exec_summary": 83.57,
|
||||
"live_ast_acc": 74.41,
|
||||
"multi_turn_acc": 34.12,
|
||||
"relevance": 83.33,
|
||||
"irrelevance": 74.75,
|
||||
"organization": "OpenAI",
|
||||
"license": "Proprietary",
|
||||
"provider": "openai",
|
||||
},
|
||||
{
|
||||
"overall_acc": 62.79,
|
||||
"model": "o1-mini-2024-09-12",
|
||||
"type": "Prompt",
|
||||
"link": "https://openai.com/index/openai-o1-mini-advancing-cost-efficient-reasoning/",
|
||||
"cost": 29.76,
|
||||
"latency": 8.44,
|
||||
"ast_summary": 78.92,
|
||||
"exec_summary": 82.7,
|
||||
"live_ast_acc": 78.14,
|
||||
"multi_turn_acc": 28.25,
|
||||
"relevance": 61.11,
|
||||
"irrelevance": 89.62,
|
||||
"organization": "OpenAI",
|
||||
"license": "Proprietary",
|
||||
"provider": "openai",
|
||||
},
|
||||
{
|
||||
"overall_acc": 62.73,
|
||||
"model": "Functionary-Medium-v3.1",
|
||||
"type": "FC",
|
||||
"link": "https://huggingface.co/meetkai/functionary-medium-v3.1",
|
||||
"cost": "N/A",
|
||||
"latency": 14.06,
|
||||
"ast_summary": 89.88,
|
||||
"exec_summary": 91.32,
|
||||
"live_ast_acc": 76.63,
|
||||
"multi_turn_acc": 21.62,
|
||||
"relevance": 72.22,
|
||||
"irrelevance": 76.08,
|
||||
"organization": "MeetKai",
|
||||
"license": "MIT",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 62.19,
|
||||
"model": "Gemini-1.5-Pro-002",
|
||||
"type": "Prompt",
|
||||
"link": "https://deepmind.google/technologies/gemini/pro/",
|
||||
"cost": 7.05,
|
||||
"latency": 5.94,
|
||||
"ast_summary": 88.58,
|
||||
"exec_summary": 91.27,
|
||||
"live_ast_acc": 76.72,
|
||||
"multi_turn_acc": 20.75,
|
||||
"relevance": 72.22,
|
||||
"irrelevance": 78.15,
|
||||
"organization": "Google",
|
||||
"license": "Proprietary",
|
||||
"provider": "google",
|
||||
},
|
||||
{
|
||||
"overall_acc": 61.83,
|
||||
"model": "Hammer2.1-7b",
|
||||
"type": "FC",
|
||||
"link": "https://huggingface.co/MadeAgents/Hammer2.1-7b",
|
||||
"cost": "N/A",
|
||||
"latency": 2.08,
|
||||
"ast_summary": 88.65,
|
||||
"exec_summary": 85.48,
|
||||
"live_ast_acc": 75.11,
|
||||
"multi_turn_acc": 23.5,
|
||||
"relevance": 82.35,
|
||||
"irrelevance": 78.59,
|
||||
"organization": "MadeAgents",
|
||||
"license": "cc-by-nc-4.0",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 61.74,
|
||||
"model": "Gemini-2.0-Flash-Exp",
|
||||
"type": "Prompt",
|
||||
"link": "https://deepmind.google/technologies/gemini/flash/",
|
||||
"cost": 0.0,
|
||||
"latency": 1.18,
|
||||
"ast_summary": 89.96,
|
||||
"exec_summary": 79.89,
|
||||
"live_ast_acc": 82.01,
|
||||
"multi_turn_acc": 17.88,
|
||||
"relevance": 77.78,
|
||||
"irrelevance": 86.44,
|
||||
"organization": "Google",
|
||||
"license": "Proprietary",
|
||||
"provider": "google",
|
||||
},
|
||||
{
|
||||
"overall_acc": 61.38,
|
||||
"model": "Amazon-Nova-Pro-v1:0",
|
||||
"type": "FC",
|
||||
"link": "https://aws.amazon.com/cn/ai/generative-ai/nova/",
|
||||
"cost": 5.26,
|
||||
"latency": 2.67,
|
||||
"ast_summary": 84.46,
|
||||
"exec_summary": 85.64,
|
||||
"live_ast_acc": 74.32,
|
||||
"multi_turn_acc": 26.12,
|
||||
"relevance": 77.78,
|
||||
"irrelevance": 70.98,
|
||||
"organization": "Amazon",
|
||||
"license": "Proprietary",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 61.31,
|
||||
"model": "Qwen2.5-72B-Instruct",
|
||||
"type": "Prompt",
|
||||
"link": "https://huggingface.co/Qwen/Qwen2.5-72B-Instruct",
|
||||
"cost": "N/A",
|
||||
"latency": 3.72,
|
||||
"ast_summary": 90.81,
|
||||
"exec_summary": 92.7,
|
||||
"live_ast_acc": 75.3,
|
||||
"multi_turn_acc": 18,
|
||||
"relevance": 100,
|
||||
"irrelevance": 72.81,
|
||||
"organization": "Qwen",
|
||||
"license": "qwen",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 60.97,
|
||||
"model": "Gemini-1.5-Pro-002",
|
||||
"type": "FC",
|
||||
"link": "https://deepmind.google/technologies/gemini/pro/",
|
||||
"cost": 5.39,
|
||||
"latency": 2.07,
|
||||
"ast_summary": 87.29,
|
||||
"exec_summary": 84.61,
|
||||
"live_ast_acc": 76.28,
|
||||
"multi_turn_acc": 21.62,
|
||||
"relevance": 72.22,
|
||||
"irrelevance": 76.9,
|
||||
"organization": "Google",
|
||||
"license": "Proprietary",
|
||||
"provider": "google",
|
||||
},
|
||||
{
|
||||
"overall_acc": 60.89,
|
||||
"model": "GPT-4o-mini-2024-07-18",
|
||||
"type": "Prompt",
|
||||
"link": "https://openai.com/index/gpt-4o-mini-advancing-cost-efficient-intelligence/",
|
||||
"cost": 0.84,
|
||||
"latency": 1.31,
|
||||
"ast_summary": 86.77,
|
||||
"exec_summary": 80.84,
|
||||
"live_ast_acc": 76.5,
|
||||
"multi_turn_acc": 22,
|
||||
"relevance": 83.33,
|
||||
"irrelevance": 80.67,
|
||||
"organization": "OpenAI",
|
||||
"license": "Proprietary",
|
||||
"provider": "openai",
|
||||
},
|
||||
{
|
||||
"overall_acc": 60.59,
|
||||
"model": "Gemini-2.0-Flash-Exp",
|
||||
"type": "FC",
|
||||
"link": "https://deepmind.google/technologies/gemini/flash/",
|
||||
"cost": 0.0,
|
||||
"latency": 0.85,
|
||||
"ast_summary": 85.1,
|
||||
"exec_summary": 77.46,
|
||||
"live_ast_acc": 79.03,
|
||||
"multi_turn_acc": 20.25,
|
||||
"relevance": 55.56,
|
||||
"irrelevance": 91.51,
|
||||
"organization": "Google",
|
||||
"license": "Proprietary",
|
||||
"provider": "google",
|
||||
},
|
||||
{
|
||||
"overall_acc": 60.46,
|
||||
"model": "Gemini-1.5-Pro-001",
|
||||
"type": "Prompt",
|
||||
"link": "https://deepmind.google/technologies/gemini/pro/",
|
||||
"cost": 7.0,
|
||||
"latency": 1.54,
|
||||
"ast_summary": 85.56,
|
||||
"exec_summary": 85.77,
|
||||
"live_ast_acc": 76.68,
|
||||
"multi_turn_acc": 18.88,
|
||||
"relevance": 55.56,
|
||||
"irrelevance": 84.81,
|
||||
"organization": "Google",
|
||||
"license": "Proprietary",
|
||||
"provider": "google",
|
||||
},
|
||||
{
|
||||
"overall_acc": 60.38,
|
||||
"model": "Gemini-Exp-1206",
|
||||
"type": "FC",
|
||||
"link": "https://blog.google/feed/gemini-exp-1206/",
|
||||
"cost": 0.0,
|
||||
"latency": 3.42,
|
||||
"ast_summary": 85.17,
|
||||
"exec_summary": 80.86,
|
||||
"live_ast_acc": 78.54,
|
||||
"multi_turn_acc": 20.25,
|
||||
"relevance": 77.78,
|
||||
"irrelevance": 79.64,
|
||||
"organization": "Google",
|
||||
"license": "Proprietary",
|
||||
"provider": "google",
|
||||
},
|
||||
{
|
||||
"overall_acc": 59.67,
|
||||
"model": "Qwen2.5-32B-Instruct",
|
||||
"type": "Prompt",
|
||||
"link": "https://huggingface.co/Qwen/Qwen2.5-32B-Instruct",
|
||||
"cost": "N/A",
|
||||
"latency": 2.26,
|
||||
"ast_summary": 85.81,
|
||||
"exec_summary": 89.79,
|
||||
"live_ast_acc": 74.23,
|
||||
"multi_turn_acc": 17.75,
|
||||
"relevance": 100,
|
||||
"irrelevance": 73.75,
|
||||
"organization": "Qwen",
|
||||
"license": "apache-2.0",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 59.57,
|
||||
"model": "GPT-4-turbo-2024-04-09",
|
||||
"type": "Prompt",
|
||||
"link": "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
|
||||
"cost": 58.87,
|
||||
"latency": 1.24,
|
||||
"ast_summary": 90.88,
|
||||
"exec_summary": 89.45,
|
||||
"live_ast_acc": 63.84,
|
||||
"multi_turn_acc": 30.25,
|
||||
"relevance": 100,
|
||||
"irrelevance": 35.57,
|
||||
"organization": "OpenAI",
|
||||
"license": "Proprietary",
|
||||
"provider": "openai",
|
||||
},
|
||||
{
|
||||
"overall_acc": 59.42,
|
||||
"model": "Gemini-1.5-Pro-001",
|
||||
"type": "FC",
|
||||
"link": "https://deepmind.google/technologies/gemini/pro/",
|
||||
"cost": 5.1,
|
||||
"latency": 1.43,
|
||||
"ast_summary": 84.33,
|
||||
"exec_summary": 87.95,
|
||||
"live_ast_acc": 76.23,
|
||||
"multi_turn_acc": 16,
|
||||
"relevance": 50,
|
||||
"irrelevance": 84.39,
|
||||
"organization": "Google",
|
||||
"license": "Proprietary",
|
||||
"provider": "google",
|
||||
},
|
||||
{
|
||||
"overall_acc": 59.07,
|
||||
"model": "Hammer2.1-3b",
|
||||
"type": "FC",
|
||||
"link": "https://huggingface.co/MadeAgents/Hammer2.1-3b",
|
||||
"cost": "N/A",
|
||||
"latency": 1.95,
|
||||
"ast_summary": 86.85,
|
||||
"exec_summary": 84.09,
|
||||
"live_ast_acc": 74.04,
|
||||
"multi_turn_acc": 17.38,
|
||||
"relevance": 82.35,
|
||||
"irrelevance": 81.87,
|
||||
"organization": "MadeAgents",
|
||||
"license": "qwen-research",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 58.45,
|
||||
"model": "mistral-large-2407",
|
||||
"type": "FC",
|
||||
"link": "https://mistral.ai/news/mistral-large-2407/",
|
||||
"cost": 12.68,
|
||||
"latency": 3.12,
|
||||
"ast_summary": 86.81,
|
||||
"exec_summary": 84.38,
|
||||
"live_ast_acc": 69.88,
|
||||
"multi_turn_acc": 23.75,
|
||||
"relevance": 72.22,
|
||||
"irrelevance": 52.85,
|
||||
"organization": "Mistral AI",
|
||||
"license": "Proprietary",
|
||||
"provider": "mistral",
|
||||
},
|
||||
{
|
||||
"overall_acc": 58.42,
|
||||
"model": "ToolACE-8B",
|
||||
"type": "FC",
|
||||
"link": "https://huggingface.co/Team-ACE/ToolACE-8B",
|
||||
"cost": "N/A",
|
||||
"latency": 5.24,
|
||||
"ast_summary": 87.54,
|
||||
"exec_summary": 89.21,
|
||||
"live_ast_acc": 78.59,
|
||||
"multi_turn_acc": 7.75,
|
||||
"relevance": 83.33,
|
||||
"irrelevance": 87.88,
|
||||
"organization": "Huawei Noah & USTC",
|
||||
"license": "Apache-2.0",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 57.78,
|
||||
"model": "xLAM-8x22b-r",
|
||||
"type": "FC",
|
||||
"link": "https://huggingface.co/Salesforce/xLAM-8x22b-r",
|
||||
"cost": "N/A",
|
||||
"latency": 9.26,
|
||||
"ast_summary": 83.69,
|
||||
"exec_summary": 87.88,
|
||||
"live_ast_acc": 72.59,
|
||||
"multi_turn_acc": 16.25,
|
||||
"relevance": 88.89,
|
||||
"irrelevance": 67.81,
|
||||
"organization": "Salesforce",
|
||||
"license": "cc-by-nc-4.0",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 57.68,
|
||||
"model": "Qwen2.5-14B-Instruct",
|
||||
"type": "Prompt",
|
||||
"link": "https://huggingface.co/Qwen/Qwen2.5-14B-Instruct",
|
||||
"cost": "N/A",
|
||||
"latency": 2.02,
|
||||
"ast_summary": 85.69,
|
||||
"exec_summary": 88.84,
|
||||
"live_ast_acc": 74.14,
|
||||
"multi_turn_acc": 12.25,
|
||||
"relevance": 77.78,
|
||||
"irrelevance": 77.06,
|
||||
"organization": "Qwen",
|
||||
"license": "apache-2.0",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 57.23,
|
||||
"model": "DeepSeek-V3",
|
||||
"type": "FC",
|
||||
"link": "https://api-docs.deepseek.com/news/news1226",
|
||||
"cost": "N/A",
|
||||
"latency": 2.58,
|
||||
"ast_summary": 89.17,
|
||||
"exec_summary": 83.39,
|
||||
"live_ast_acc": 68.41,
|
||||
"multi_turn_acc": 18.62,
|
||||
"relevance": 88.89,
|
||||
"irrelevance": 59.36,
|
||||
"organization": "DeepSeek",
|
||||
"license": "DeepSeek License",
|
||||
"provider": "unknown",
|
||||
},
|
||||
{
|
||||
"overall_acc": 57.09,
|
||||
"model": "Gemini-1.5-Flash-001",
|
||||
"type": "Prompt",
|
||||
"link": "https://deepmind.google/technologies/gemini/flash/",
|
||||
"cost": 0.48,
|
||||
"latency": 0.71,
|
||||
"ast_summary": 85.69,
|
||||
"exec_summary": 83.59,
|
||||
"live_ast_acc": 68.9,
|
||||
"multi_turn_acc": 19.5,
|
||||
"relevance": 83.33,
|
||||
"irrelevance": 62.78,
|
||||
"organization": "Google",
|
||||
"license": "Proprietary",
|
||||
"provider": "google",
|
||||
},
|
||||
{
|
||||
"overall_acc": 56.79,
|
||||
"model": "Gemini-1.5-Flash-002",
|
||||
"type": "Prompt",
|
||||
"link": "https://deepmind.google/technologies/gemini/flash/",
|
||||
"cost": 0.46,
|
||||
"latency": 0.81,
|
||||
"ast_summary": 81.65,
|
||||
"exec_summary": 80.64,
|
||||
"live_ast_acc": 76.72,
|
||||
"multi_turn_acc": 12.5,
|
||||
"relevance": 83.33,
|
||||
"irrelevance": 78.49,
|
||||
"organization": "Google",
|
||||
"license": "Proprietary",
|
||||
"provider": "google",
|
||||
},
|
||||
]


supported_top_tool_models = [
    {
        "cost": item["cost"],
        "model": item["model"],
        "type": item["type"],
        "provider": item["provider"],
    }
    for item in leaderboard_data
    if item["provider"] in VALID_PROVIDERS
]
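A small sketch of how the filtered list above can be consumed (the defining module's import path is not shown in this excerpt, so the import is omitted):

```python
# supported_top_tool_models preserves leaderboard order, so the first entry is the
# highest-accuracy tool-calling model whose provider is in VALID_PROVIDERS.
best = supported_top_tool_models[0]
print(best["provider"], best["model"], best["cost"])
```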
@@ -10,7 +10,7 @@ from rich.panel import Panel
from rich.text import Text

from ra_aid.logging_config import get_logger
from ra_aid.models_params import models_params, DEFAULT_BASE_LATENCY
from ra_aid.models_params import DEFAULT_BASE_LATENCY, models_params
from ra_aid.proc.interactive import run_interactive_command
from ra_aid.text.processing import truncate_output
from ra_aid.tools.memory import _global_memory, log_work_event

@@ -142,7 +142,11 @@ def run_programming_task(
    # Get provider/model specific latency coefficient
    provider = _global_memory.get("config", {}).get("provider", "")
    model = _global_memory.get("config", {}).get("model", "")
    latency = models_params.get(provider, {}).get(model, {}).get("latency_coefficient", DEFAULT_BASE_LATENCY)
    latency = (
        models_params.get(provider, {})
        .get(model, {})
        .get("latency_coefficient", DEFAULT_BASE_LATENCY)
    )

    result = run_interactive_command(command, expected_runtime_seconds=latency)
    print()
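The chained .get() lookup above falls back to DEFAULT_BASE_LATENCY when either the provider or the model is missing; a sketch of the assumed nesting (keys illustrative):

```python
# Assumed shape: models_params = {"provider": {"model-name": {"latency_coefficient": ...}}}
latency = (
    models_params.get("unknown-provider", {})
    .get("unknown-model", {})
    .get("latency_coefficient", DEFAULT_BASE_LATENCY)
)  # -> DEFAULT_BASE_LATENCY
```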
@@ -82,7 +82,8 @@ def run_shell_command(
    try:
        print()
        output, return_code = run_interactive_command(
            ["/bin/bash", "-c", command], expected_runtime_seconds=timeout
            ["/bin/bash", "-c", command],
            expected_runtime_seconds=timeout,
        )
        print()
        result = {

@@ -1,5 +1,6 @@
"""Unit tests for agent_utils.py."""

from typing import Any, Dict, Literal
from unittest.mock import Mock, patch

import litellm

@@ -127,7 +128,10 @@ def test_create_agent_openai(mock_model, mock_memory):

    assert agent == "ciayn_agent"
    mock_ciayn.assert_called_once_with(
        mock_model, [], max_tokens=models_params["openai"]["gpt-4"]["token_limit"]
        mock_model,
        [],
        max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
        config={"provider": "openai", "model": "gpt-4"},
    )


@@ -141,7 +145,10 @@ def test_create_agent_no_token_limit(mock_model, mock_memory):

    assert agent == "ciayn_agent"
    mock_ciayn.assert_called_once_with(
        mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT
        mock_model,
        [],
        max_tokens=DEFAULT_TOKEN_LIMIT,
        config={"provider": "unknown", "model": "unknown-model"},
    )


@@ -158,6 +165,7 @@ def test_create_agent_missing_config(mock_model, mock_memory):
        mock_model,
        [],
        max_tokens=DEFAULT_TOKEN_LIMIT,
        config={"provider": "openai"},
    )


@@ -201,7 +209,10 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory):

    assert agent == "ciayn_agent"
    mock_ciayn.assert_called_once_with(
        mock_model, [], max_tokens=models_params["openai"]["gpt-4"]["token_limit"]
        mock_model,
        [],
        max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
        config={"provider": "openai", "model": "gpt-4"},
    )


@@ -275,3 +286,118 @@ def test_get_model_token_limit_planner(mock_memory):
        mock_get_info.return_value = {"max_input_tokens": 120000}
        token_limit = get_model_token_limit(config, "planner")
        assert token_limit == 120000


# New tests for private helper methods in agent_utils.py


def test_setup_and_restore_interrupt_handling():
    import signal

    from ra_aid.agent_utils import (
        _request_interrupt,
        _restore_interrupt_handling,
        _setup_interrupt_handling,
    )

    original_handler = signal.getsignal(signal.SIGINT)
    handler = _setup_interrupt_handling()
    # Verify the SIGINT handler is set to _request_interrupt
    assert signal.getsignal(signal.SIGINT) == _request_interrupt
    _restore_interrupt_handling(handler)
    # Verify the SIGINT handler is restored to the original
    assert signal.getsignal(signal.SIGINT) == original_handler


def test_increment_and_decrement_agent_depth():
    from ra_aid.agent_utils import (
        _decrement_agent_depth,
        _global_memory,
        _increment_agent_depth,
    )

    _global_memory["agent_depth"] = 10
    _increment_agent_depth()
    assert _global_memory["agent_depth"] == 11
    _decrement_agent_depth()
    assert _global_memory["agent_depth"] == 10


def test_run_agent_stream(monkeypatch):
    from ra_aid.agent_utils import _global_memory, _run_agent_stream

    # Create a dummy agent that yields one chunk
    class DummyAgent:
        def stream(self, input_data, cfg: dict):
            yield {"content": "chunk1"}

    dummy_agent = DummyAgent()
    # Set flags so that _run_agent_stream will reset them
    _global_memory["plan_completed"] = True
    _global_memory["task_completed"] = True
    _global_memory["completion_message"] = "existing"
    call_flag = {"called": False}

    def fake_print_agent_output(
        chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
    ):
        call_flag["called"] = True

    monkeypatch.setattr(
        "ra_aid.agent_utils.print_agent_output", fake_print_agent_output
    )
    _run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
    assert call_flag["called"]
    assert _global_memory["plan_completed"] is False
    assert _global_memory["task_completed"] is False
    assert _global_memory["completion_message"] == ""


def test_execute_test_command_wrapper(monkeypatch):
    from ra_aid.agent_utils import _execute_test_command_wrapper

    # Patch execute_test_command to return a testable tuple
    def fake_execute(config, orig, tests, auto):
        return (True, "new prompt", auto, tests + 1)

    monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
    result = _execute_test_command_wrapper("orig", {}, 0, False)
    assert result == (True, "new prompt", False, 1)


def test_handle_api_error_valueerror():
    import pytest

    from ra_aid.agent_utils import _handle_api_error

    # ValueError not containing "code" or "429" should be re-raised
    with pytest.raises(ValueError):
        _handle_api_error(ValueError("some error"), 0, 5, 1)


def test_handle_api_error_max_retries():
    import pytest

    from ra_aid.agent_utils import _handle_api_error

    # When attempt reaches max retries, a RuntimeError should be raised
    with pytest.raises(RuntimeError):
        _handle_api_error(Exception("error code 429"), 4, 5, 1)


def test_handle_api_error_retry(monkeypatch):
    import time

    from ra_aid.agent_utils import _handle_api_error

    # Patch time.monotonic and time.sleep to simulate immediate delay expiration
    fake_time = [0]

    def fake_monotonic():
        fake_time[0] += 0.5
        return fake_time[0]

    monkeypatch.setattr(time, "monotonic", fake_monotonic)
    monkeypatch.setattr(time, "sleep", lambda s: None)
    # Should not raise error when attempt is lower than max retries
    _handle_api_error(Exception("error code 429"), 0, 5, 1)
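Taken together, the three _handle_api_error tests above pin down its retry contract; a hedged summary in code form (the positional parameters — error, attempt, max retries, base delay — are inferred from the test invocations, not from a documented signature):

```python
from ra_aid.agent_utils import _handle_api_error

_handle_api_error(Exception("error code 429"), 0, 5, 1)   # retriable: sleeps, then returns
# _handle_api_error(Exception("error code 429"), 4, 5, 1) # out of retries: raises RuntimeError
# _handle_api_error(ValueError("some error"), 0, 5, 1)    # non-retriable: re-raises the ValueError
```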
@@ -1,11 +1,38 @@
import unittest
from unittest.mock import Mock

import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage

from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern
from ra_aid.exceptions import ToolExecutionError


# Dummy tool function for testing retry and fallback behavior
def dummy_tool():
    dummy_tool.attempt += 1
    if dummy_tool.attempt < 3:
        raise Exception("Simulated failure")
    return "dummy success"


dummy_tool.attempt = 0


class DummyTool:
    def __init__(self, func):
        self.func = func


class DummyModel:
    def invoke(self, _messages: list[BaseMessage]):
        return AIMessage("dummy_tool()")

    def bind_tools(self, tools, tool_choice):
        pass


# Fixtures from the source file
@pytest.fixture
def mock_model():
    """Create a mock language model."""

@@ -21,6 +48,7 @@ def agent(mock_model):
    return CiaynAgent(mock_model, tools, max_history_messages=3)


# Trimming test functions
def test_trim_chat_history_preserves_initial_messages(agent):
    """Test that initial messages are preserved during trimming."""
    initial_messages = [

@@ -33,9 +61,7 @@ def test_trim_chat_history_preserves_initial_messages(agent):
        HumanMessage(content="Chat 3"),
        AIMessage(content="Chat 4"),
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Verify initial messages are preserved
    assert result[:2] == initial_messages
    # Verify only last 3 chat messages are kept (due to max_history_messages=3)

@@ -47,9 +73,7 @@ def test_trim_chat_history_under_limit(agent):
    """Test trimming when chat history is under the maximum limit."""
    initial_messages = [HumanMessage(content="Initial")]
    chat_history = [HumanMessage(content="Chat 1"), AIMessage(content="Chat 2")]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Verify no trimming occurred
    assert len(result) == 3
    assert result == initial_messages + chat_history

@@ -65,9 +89,7 @@ def test_trim_chat_history_over_limit(agent):
        AIMessage(content="Chat 4"),
        HumanMessage(content="Chat 5"),
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Verify correct trimming
    assert len(result) == 4  # initial + max_history_messages
    assert result[0] == initial_messages[0]  # Initial message preserved

@@ -83,9 +105,7 @@ def test_trim_chat_history_empty_initial(agent):
        HumanMessage(content="Chat 3"),
        AIMessage(content="Chat 4"),
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Verify only last 3 messages are kept
    assert len(result) == 3
    assert result == chat_history[-3:]

@@ -98,9 +118,7 @@ def test_trim_chat_history_empty_chat(agent):
        AIMessage(content="Initial 2"),
    ]
    chat_history = []

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Verify initial messages are preserved and no trimming occurred
    assert result == initial_messages
    assert len(result) == 2

@@ -109,16 +127,13 @@ def test_trim_chat_history_empty_chat(agent):
def test_trim_chat_history_token_limit():
    """Test trimming based on token limit."""
    agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=25)

    initial_messages = [HumanMessage(content="Initial")]  # ~2 tokens
    chat_history = [
        HumanMessage(content="A" * 40),  # ~10 tokens
        AIMessage(content="B" * 40),  # ~10 tokens
        HumanMessage(content="C" * 40),  # ~10 tokens
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Should keep initial message (~2 tokens) and last message (~10 tokens)
    assert len(result) == 2
    assert result[0] == initial_messages[0]

@@ -128,16 +143,13 @@ def test_trim_chat_history_token_limit():
def test_trim_chat_history_no_token_limit():
    """Test trimming with no token limit set."""
    agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)

    initial_messages = [HumanMessage(content="Initial")]
    chat_history = [
        HumanMessage(content="A" * 1000),
        AIMessage(content="B" * 1000),
        HumanMessage(content="C" * 1000),
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Should keep initial message and last 2 messages (max_history_messages=2)
    assert len(result) == 3
    assert result[0] == initial_messages[0]

@@ -147,7 +159,6 @@ def test_trim_chat_history_no_token_limit():
def test_trim_chat_history_both_limits():
    """Test trimming with both message count and token limits."""
    agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=35)

    initial_messages = [HumanMessage(content="Init")]  # ~1 token
    chat_history = [
        HumanMessage(content="A" * 40),  # ~10 tokens

@@ -155,9 +166,7 @@ def test_trim_chat_history_both_limits():
        HumanMessage(content="C" * 40),  # ~10 tokens
        AIMessage(content="D" * 40),  # ~10 tokens
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Should first apply message limit (keeping last 3)
    # Then token limit should further reduce to fit under 15 tokens
    assert len(result) == 2  # Initial message + 1 message under token limit

@@ -165,6 +174,38 @@ def test_trim_chat_history_both_limits():
    assert result[1] == chat_history[-1]


# Fallback tests
class TestCiaynAgentFallback(unittest.TestCase):
    def setUp(self):
        # Reset dummy_tool attempt counter before each test
        dummy_tool.attempt = 0
        self.dummy_tool = DummyTool(dummy_tool)
        self.model = DummyModel()
        # Create a CiaynAgent with the dummy tool
        self.agent = CiaynAgent(self.model, [self.dummy_tool])

    # def test_retry_logic_with_failure_recovery(self):
    #     # Test that run_agent_with_retry retries until success
    #     from ra_aid.agent_utils import run_agent_with_retry
    #
    #     config = {"max_test_cmd_retries": 0, "auto_test": True}
    #     result = run_agent_with_retry(self.agent, "dummy_tool()", config)
    #     self.assertEqual(result, "Agent run completed successfully")

    def test_switch_models_on_fallback(self):
        # Test fallback behavior by making dummy_tool always fail
        from langchain_core.messages import HumanMessage

        def always_fail():
            raise Exception("Persistent failure")

        always_fail_tool = DummyTool(always_fail)
        agent = CiaynAgent(self.model, [always_fail_tool])
        with self.assertRaises(ToolExecutionError):
            agent._execute_tool(HumanMessage("always_fail()"))


# Function call validation tests
class TestFunctionCallValidation:
    @pytest.mark.parametrize(
        "test_input",

@@ -221,3 +262,11 @@ class TestFunctionCallValidation:
    def test_multiline_responses(self, test_input):
        """Test function calls spanning multiple lines."""
        assert not validate_function_call_pattern(test_input)


class TestCiaynAgentNewMethods(unittest.TestCase):
    pass


if __name__ == "__main__":
    unittest.main()
@ -0,0 +1,372 @@
|
|||
import unittest
|
||||
|
||||
from ra_aid.exceptions import FallbackToolExecutionError
|
||||
from ra_aid.fallback_handler import FallbackHandler
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
def debug(self, msg):
|
||||
pass
|
||||
|
||||
def error(self, msg):
|
||||
pass
|
||||
|
||||
|
||||
class DummyAgent:
|
||||
provider = "openai"
|
||||
tools = []
|
||||
model = None
|
||||
|
||||
|
||||
class TestFallbackHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.config = {
|
||||
"max_tool_failures": 2,
|
||||
"fallback_tool_models": "dummy-fallback-model",
|
||||
"experimental_fallback_handler": True,
|
||||
}
|
||||
self.fallback_handler = FallbackHandler(self.config, [])
|
||||
self.logger = DummyLogger()
|
||||
self.agent = DummyAgent()
|
||||
|
||||
def dummy_tool():
|
||||
pass
|
||||
|
||||
class DummyToolWrapper:
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
self.agent.tools = [DummyToolWrapper(dummy_tool)]
|
||||
|
||||
def test_handle_failure_increments_counter(self):
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
|
||||
initial_failures = self.fallback_handler.tool_failure_consecutive_failures
|
||||
error_obj = ToolExecutionError(
|
||||
"Test error", base_message="dummy_call()", tool_name="dummy_tool"
|
||||
)
|
||||
self.fallback_handler.handle_failure(error_obj, self.agent, [])
|
||||
self.assertEqual(
|
||||
self.fallback_handler.tool_failure_consecutive_failures,
|
||||
initial_failures + 1,
|
||||
)
|
||||
|
||||
def test_attempt_fallback_resets_counter(self):
|
||||
# Monkey-patch dummy functions for fallback components
|
||||
def dummy_initialize_llm(provider, model_name, temperature=None):
|
||||
class DummyModel:
|
||||
def bind_tools(self, tools, tool_choice):
|
||||
pass
|
||||
|
||||
return DummyModel()
|
||||
|
||||
def dummy_validate_provider_env(provider):
|
||||
return True
|
||||
|
||||
import ra_aid.llm as llm
|
||||
|
||||
original_initialize = llm.initialize_llm
|
||||
original_validate = llm.validate_provider_env
|
||||
llm.initialize_llm = dummy_initialize_llm
|
||||
llm.validate_provider_env = dummy_validate_provider_env
|
||||
|
||||
self.fallback_handler.tool_failure_consecutive_failures = 2
|
||||
with self.assertRaises(FallbackToolExecutionError):
|
||||
self.fallback_handler.attempt_fallback()
|
||||
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)
|
||||
|
||||
llm.initialize_llm = original_initialize
|
||||
llm.validate_provider_env = original_validate
|
||||
|
||||
def test_load_fallback_tool_models(self):
|
||||
import ra_aid.fallback_handler as fh
|
||||
|
||||
original_supported = fh.supported_top_tool_models
|
||||
fh.supported_top_tool_models = [
|
||||
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
|
||||
]
|
||||
models = self.fallback_handler._load_fallback_tool_models(self.config)
|
||||
self.assertIsInstance(models, list)
|
||||
fh.supported_top_tool_models = original_supported
|
||||
|
||||
def test_extract_failed_tool_name(self):
|
||||
from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
|
||||
|
||||
# Case when tool_name is provided
|
||||
error1 = ToolExecutionError(
|
||||
"Error", base_message="dummy", tool_name="dummy_tool"
|
||||
)
|
||||
name1 = self.fallback_handler.extract_failed_tool_name(error1)
|
||||
self.assertEqual(name1, "dummy_tool")
|
||||
# Case when tool_name is not provided but regex works
|
||||
error2 = ToolExecutionError('error with name="test_tool"')
|
||||
name2 = self.fallback_handler.extract_failed_tool_name(error2)
|
||||
self.assertEqual(name2, "test_tool")
|
||||
# Case when regex fails and exception is raised
|
||||
error3 = ToolExecutionError("no tool name here")
|
||||
with self.assertRaises(FallbackToolExecutionError):
|
||||
self.fallback_handler.extract_failed_tool_name(error3)
|
||||
|
||||
def test_find_tool_to_bind(self):
|
||||
class DummyWrapper:
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def dummy_func(_args):
|
||||
return "result"
|
||||
|
||||
dummy_wrapper = DummyWrapper(dummy_func)
|
||||
self.agent.tools.append(dummy_wrapper)
|
||||
tool = self.fallback_handler._find_tool_to_bind(self.agent, dummy_func.__name__)
|
||||
self.assertIsNotNone(tool)
|
||||
self.assertEqual(tool.func.__name__, dummy_func.__name__)
|
||||
|
||||
def test_bind_tool_model(self):
|
||||
# Setup a dummy simple_model with bind_tools method
|
||||
class DummyModel:
|
||||
def bind_tools(self, tools, tool_choice=None):
|
||||
self.bound = True
|
||||
self.tools = tools
|
||||
self.tool_choice = tool_choice
|
||||
return self
|
||||
|
||||
def with_retry(self, stop_after_attempt):
|
||||
return self
|
||||
|
||||
def invoke(self, msg_list):
|
||||
return "dummy_response"
|
||||
|
||||
dummy_model = DummyModel()
|
||||
|
||||
# Set current tool for binding
|
||||
class DummyTool:
|
||||
def invoke(self, args):
|
||||
return "result"
|
||||
|
||||
self.fallback_handler.current_tool_to_bind = DummyTool()
|
||||
self.fallback_handler.current_failing_tool_name = "test_tool"
|
||||
# Test with force calling ("fc") type
|
||||
fallback_model_fc = {"type": "fc"}
|
||||
bound_model_fc = self.fallback_handler._bind_tool_model(
|
||||
dummy_model, fallback_model_fc
|
||||
)
|
||||
self.assertTrue(hasattr(bound_model_fc, "tool_choice"))
|
||||
self.assertEqual(bound_model_fc.tool_choice, "test_tool")
|
||||
# Test with prompt type
|
||||
fallback_model_prompt = {"type": "prompt"}
|
||||
bound_model_prompt = self.fallback_handler._bind_tool_model(
|
||||
dummy_model, fallback_model_prompt
|
||||
)
|
||||
self.assertTrue(bound_model_prompt.tool_choice is None)
|
||||
|
||||
def test_invoke_fallback(self):
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
# Successful fallback scenario with proper API key set
|
||||
with (
|
||||
patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}),
|
||||
patch(
|
||||
"ra_aid.fallback_handler.supported_top_tool_models",
|
||||
new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
|
||||
),
|
||||
patch("ra_aid.fallback_handler.validate_provider_env", return_value=True),
|
||||
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
|
||||
):
|
||||
|
||||
class DummyModel:
|
||||
def bind_tools(self, tools, tool_choice=None):
|
||||
return self
|
||||
|
||||
def with_retry(self, stop_after_attempt):
|
||||
return self
|
||||
|
||||
def invoke(self, msg_list):
|
||||
return DummyResponse()
|
||||
|
||||
class DummyResponse:
|
||||
additional_kwargs = {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "1",
|
||||
"type": "test",
|
||||
"function": {"name": "dummy_tool", "arguments": '{"a":1}'},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
def dummy_initialize_llm(provider, model_name):
|
||||
return DummyModel()
|
||||
|
||||
mock_init_llm.side_effect = dummy_initialize_llm
|
||||
|
||||
# Set current tool for fallback
|
||||
class DummyTool:
|
||||
def invoke(self, args):
|
||||
return "tool_result"
|
||||
|
||||
self.fallback_handler.current_tool_to_bind = DummyTool()
|
||||
self.fallback_handler.current_failing_tool_name = "dummy_tool"
|
||||
# Add dummy tool for lookup in invoke_prompt_tool_call
|
||||
self.fallback_handler.tools.append(
|
||||
type(
|
||||
"DummyToolWrapper",
|
||||
(),
|
||||
{
|
||||
"func": type("DummyToolFunc", (), {"__name__": "dummy_tool"})(),
|
||||
"invoke": lambda self, args=None: "tool_result",
|
||||
},
|
||||
)
|
||||
)
|
||||
result = self.fallback_handler.invoke_fallback(
|
||||
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
|
||||
)
|
||||
self.assertIsInstance(result, list)
|
||||
self.assertEqual(result[1], "tool_result")
|
||||
|
||||
# Failed fallback scenario due to missing API key (simulate by empty environment)
|
||||
with (
|
||||
patch.dict(os.environ, {}, clear=True),
|
||||
patch(
|
||||
"ra_aid.fallback_handler.supported_top_tool_models",
|
||||
new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
|
||||
),
|
||||
patch("ra_aid.fallback_handler.validate_provider_env", return_value=False),
|
||||
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
|
||||
):
|
||||
|
||||
class FailingDummyModel:
|
||||
def bind_tools(self, tools, tool_choice=None):
|
||||
return self
|
||||
|
||||
def with_retry(self, stop_after_attempt):
|
||||
return self
|
||||
|
||||
def invoke(self, msg_list):
|
||||
raise Exception("API key missing")
|
||||
|
||||
def failing_initialize_llm(provider, model_name):
|
||||
return FailingDummyModel()
|
||||
|
||||
mock_init_llm.side_effect = failing_initialize_llm
|
||||
fallback_result = self.fallback_handler.invoke_fallback(
|
||||
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
|
||||
)
|
||||
self.assertIsNone(fallback_result)
|
||||
|
||||
# Test that the overall fallback mechanism raises FallbackToolExecutionError when all models fail
|
||||
# Set failure count to trigger the fallback attempt in attempt_fallback
|
||||
from ra_aid.exceptions import FallbackToolExecutionError
|
||||
|
||||
self.fallback_handler.tool_failure_consecutive_failures = (
|
||||
self.fallback_handler.max_failures
|
||||
)
|
||||
with self.assertRaises(FallbackToolExecutionError) as cm:
|
||||
self.fallback_handler.attempt_fallback()
|
||||
self.assertIn("All fallback models have failed", str(cm.exception))
|
||||
|
||||
def test_construct_prompt_msg_list(self):
|
||||
msgs = self.fallback_handler.construct_prompt_msg_list()
|
||||
from ra_aid.fallback_handler import HumanMessage, SystemMessage
|
||||
|
||||
self.assertTrue(any(isinstance(m, SystemMessage) for m in msgs))
|
||||
self.assertTrue(any(isinstance(m, HumanMessage) for m in msgs))
|
||||
# Test with failed_messages added
|
||||
self.fallback_handler.failed_messages.append("failed_msg")
|
||||
msgs_with_fail = self.fallback_handler.construct_prompt_msg_list()
|
||||
self.assertTrue(any("failed_msg" in str(m) for m in msgs_with_fail))
|
||||
|
||||
def test_invoke_prompt_tool_call(self):
|
||||
# Create dummy tool function
|
||||
def dummy_tool_func(args):
|
||||
return "invoked_result"
|
||||
|
||||
dummy_tool_func.__name__ = "dummy_tool"
|
||||
|
||||
# Create wrapper class
|
||||
class DummyToolWrapper:
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def invoke(self, args):
|
||||
return self.func(args)
|
||||
|
||||
dummy_wrapper = DummyToolWrapper(dummy_tool_func)
|
||||
self.fallback_handler.tools = [dummy_wrapper]
|
||||
tool_call_req = {"name": "dummy_tool", "arguments": {"x": 42}}
|
||||
result = self.fallback_handler.invoke_prompt_tool_call(tool_call_req)
|
||||
self.assertEqual(result, "invoked_result")
|
||||
|
||||
def test_base_message_to_tool_call_dict(self):
|
||||
dummy_tool_call = {
|
||||
"id": "123",
|
||||
"type": "test",
|
||||
"function": {"name": "dummy_tool", "arguments": '{"x":42}'},
|
||||
}
|
||||
DummyResponse = type(
|
||||
"DummyResponse",
|
||||
(),
|
||||
{"additional_kwargs": {"tool_calls": [dummy_tool_call]}},
|
||||
)
|
||||
result = self.fallback_handler.base_message_to_tool_call_dict(DummyResponse)
|
||||
self.assertEqual(result["id"], "123")
|
||||
self.assertEqual(result["name"], "dummy_tool")
|
||||
self.assertEqual(result["arguments"], {"x": 42})
|
||||
|
||||
def test_parse_tool_arguments(self):
|
||||
args_str = '{"a": 1}'
|
||||
parsed = self.fallback_handler._parse_tool_arguments(args_str)
|
||||
self.assertEqual(parsed, {"a": 1})
|
||||
args_dict = {"b": 2}
|
||||
parsed_dict = self.fallback_handler._parse_tool_arguments(args_dict)
|
||||
self.assertEqual(parsed_dict, {"b": 2})
|
||||
|
||||
def test_get_tool_calls(self):
|
||||
DummyResponse = type("DummyResponse", (), {})()
|
||||
DummyResponse.additional_kwargs = {"tool_calls": [{"id": "1"}]}
|
||||
calls = self.fallback_handler.get_tool_calls(DummyResponse)
|
||||
self.assertEqual(calls, [{"id": "1"}])
|
||||
DummyResponse2 = type("DummyResponse2", (), {"tool_calls": [{"id": "2"}]})()
|
||||
calls2 = self.fallback_handler.get_tool_calls(DummyResponse2)
|
||||
self.assertEqual(calls2, [{"id": "2"}])
|
||||
dummy_dict = {"additional_kwargs": {"tool_calls": [{"id": "3"}]}}
|
||||
calls3 = self.fallback_handler.get_tool_calls(dummy_dict)
|
||||
self.assertEqual(calls3, [{"id": "3"}])
|
||||
|
||||
def test_handle_failure_response(self):
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
|
||||
def dummy_handle_failure(error, agent):
|
||||
return ["fallback_response"]
|
||||
|
||||
self.fallback_handler.handle_failure = dummy_handle_failure
|
||||
response = self.fallback_handler.handle_failure_response(
|
||||
ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React"
|
||||
)
|
||||
from ra_aid.fallback_handler import SystemMessage
|
||||
|
||||
self.assertTrue(all(isinstance(m, SystemMessage) for m in response))
|
||||
response_non = self.fallback_handler.handle_failure_response(
|
||||
ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other"
|
||||
)
|
||||
self.assertIsNone(response_non)
|
||||
|
||||
def test_init_msg_list_non_overlapping(self):
|
||||
# Test when the first two and last two messages do not overlap.
|
||||
full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"]
|
||||
self.fallback_handler.init_msg_list(full_list)
|
||||
# Expected merged list: first two ("msg1", "msg2") plus last two ("msg4", "msg5")
|
||||
self.assertEqual(
|
||||
self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"]
|
||||
)
|
||||
|
||||
def test_init_msg_list_with_overlap(self):
|
||||
# Test when the last two messages overlap with the first two.
|
||||
full_list = ["msg1", "msg2", "msg1", "msg3"]
|
||||
self.fallback_handler.init_msg_list(full_list)
|
||||
# Expected merged list: first two ("msg1", "msg2") plus "msg3" from the last two, since "msg1" was already present.
|
||||
self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@@ -142,7 +142,9 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch

def test_initialize_expert_unsupported_provider(clean_env):
    """Test error handling for unsupported provider in expert mode."""
    with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
    with pytest.raises(
        ValueError, match=r"Missing required environment variable for provider: unknown"
    ):
        initialize_expert_llm("unknown", "model")


@@ -235,7 +237,9 @@ def test_initialize_openai_compatible(clean_env, mock_openai):

def test_initialize_unsupported_provider(clean_env):
    """Test initialization with unsupported provider raises ValueError"""
    with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
    with pytest.raises(
        ValueError, match=r"Missing required environment variable for provider: unknown"
    ):
        initialize_llm("unknown", "model")


@@ -257,15 +261,33 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemin
        max_retries=5,
    )

    # Test error when no temperature provided for models that support it
    with pytest.raises(ValueError, match="Temperature must be provided for model"):
        initialize_llm("openai", "test-model")
    # Test default temperature when none is provided for models that support it
    initialize_llm("openai", "test-model")
    mock_openai.assert_called_with(
        api_key="test-key",
        model="test-model",
        temperature=0.7,
        timeout=180,
        max_retries=5,
    )

    with pytest.raises(ValueError, match="Temperature must be provided for model"):
        initialize_llm("anthropic", "test-model")
    initialize_llm("anthropic", "test-model")
    mock_anthropic.assert_called_with(
        api_key="test-key",
        model_name="test-model",
        temperature=0.7,
        timeout=180,
        max_retries=5,
    )

    with pytest.raises(ValueError, match="Temperature must be provided for model"):
        initialize_llm("gemini", "test-model")
    initialize_llm("gemini", "test-model")
    mock_gemini.assert_called_with(
        api_key="test-key",
        model="test-model",
        temperature=0.7,
        timeout=180,
        max_retries=5,
    )

    # Test expert models don't require temperature
    initialize_expert_llm("openai", "o1")