feat(fallback): implement fallback handler for tool execution errors to enhance error resilience and user experience

refactor(fallback): streamline fallback model selection and invocation process for improved maintainability
fix(config): reduce maximum tool failures from 3 to 2 to tighten error handling thresholds
style(console): improve error message formatting and logging for better clarity and debugging
chore(main): remove redundant fallback tool model handling from main function to simplify configuration management
This commit is contained in:
Ariel Frischer 2025-02-11 18:35:34 -08:00
parent 1388067769
commit 67ecf72a6c
7 changed files with 260 additions and 136 deletions

View File

@ -427,15 +427,6 @@ def main():
_global_memory["config"]["planner_model"] = args.planner_model or args.model _global_memory["config"]["planner_model"] = args.planner_model or args.model
_global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool
_global_memory["config"]["fallback_tool_models"] = (
[
model.strip()
for model in args.fallback_tool_models.split(",")
if model.strip()
]
if args.fallback_tool_models
else []
)
# Store research config with fallback to base values # Store research config with fallback to base values
_global_memory["config"]["research_provider"] = ( _global_memory["config"]["research_provider"] = (
@ -445,15 +436,6 @@ def main():
# Store fallback tool configuration # Store fallback tool configuration
_global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool
_global_memory["config"]["fallback_tool_models"] = (
[
model.strip()
for model in args.fallback_tool_models.split(",")
if model.strip()
]
if args.fallback_tool_models
else []
)
# Run research stage # Run research stage
print_stage_header("Research Stage") print_stage_header("Research Stage")

View File

@ -16,7 +16,6 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.messages import ( from langchain_core.messages import (
BaseMessage, BaseMessage,
HumanMessage, HumanMessage,
InvalidToolCall,
trim_messages, trim_messages,
) )
from langchain_core.tools import tool from langchain_core.tools import tool
@ -339,9 +338,6 @@ def run_research_agent(
if memory is None: if memory is None:
memory = MemorySaver() memory = MemorySaver()
if thread_id is None:
thread_id = str(uuid.uuid4())
tools = get_research_tools( tools = get_research_tools(
research_only=research_only, research_only=research_only,
expert_enabled=expert_enabled, expert_enabled=expert_enabled,
@ -413,7 +409,8 @@ def run_research_agent(
if agent is not None: if agent is not None:
logger.debug("Research agent completed successfully") logger.debug("Research agent completed successfully")
_result = run_agent_with_retry(agent, prompt, run_config) fallback_handler = FallbackHandler(config, tools)
_result = run_agent_with_retry(agent, prompt, run_config, fallback_handler)
if _result: if _result:
# Log research completion # Log research completion
log_work_event(f"Completed research phase for: {base_task_or_query}") log_work_event(f"Completed research phase for: {base_task_or_query}")
@ -529,7 +526,8 @@ def run_web_research_agent(
console.print(Panel(Markdown(console_message), title="🔬 Researching...")) console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
logger.debug("Web research agent completed successfully") logger.debug("Web research agent completed successfully")
_result = run_agent_with_retry(agent, prompt, run_config) fallback_handler = FallbackHandler(config, tools)
_result = run_agent_with_retry(agent, prompt, run_config, fallback_handler)
if _result: if _result:
# Log web research completion # Log web research completion
log_work_event(f"Completed web research phase for: {query}") log_work_event(f"Completed web research phase for: {query}")
@ -634,7 +632,10 @@ def run_planning_agent(
try: try:
print_stage_header("Planning Stage") print_stage_header("Planning Stage")
logger.debug("Planning agent completed successfully") logger.debug("Planning agent completed successfully")
_result = run_agent_with_retry(agent, planning_prompt, run_config) fallback_handler = FallbackHandler(config, tools)
_result = run_agent_with_retry(
agent, planning_prompt, run_config, fallback_handler
)
if _result: if _result:
# Log planning completion # Log planning completion
log_work_event(f"Completed planning phase for: {base_task}") log_work_event(f"Completed planning phase for: {base_task}")
@ -739,7 +740,8 @@ def run_task_implementation_agent(
try: try:
logger.debug("Implementation agent completed successfully") logger.debug("Implementation agent completed successfully")
_result = run_agent_with_retry(agent, prompt, run_config) fallback_handler = FallbackHandler(config, tools)
_result = run_agent_with_retry(agent, prompt, run_config, fallback_handler)
if _result: if _result:
# Log task implementation completion # Log task implementation completion
log_work_event(f"Completed implementation of task: {task}") log_work_event(f"Completed implementation of task: {task}")
@ -805,7 +807,7 @@ def _decrement_agent_depth():
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1 _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
def _run_agent_stream(agent, prompt, config): def _run_agent_stream(agent: CompiledGraph, prompt: str, config: dict):
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config): for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
logger.debug("Agent output: %s", chunk) logger.debug("Agent output: %s", chunk)
check_interrupt() check_interrupt()
@ -840,7 +842,9 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
time.sleep(0.1) time.sleep(0.1)
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: def run_agent_with_retry(
agent, prompt: str, config: dict, fallback_handler: FallbackHandler
) -> Optional[str]:
"""Run an agent with retry logic for API errors.""" """Run an agent with retry logic for API errors."""
logger.debug("Running agent with prompt length: %d", len(prompt)) logger.debug("Running agent with prompt length: %d", len(prompt))
original_handler = _setup_interrupt_handling() original_handler = _setup_interrupt_handling()
@ -850,7 +854,6 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES) _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
auto_test = config.get("auto_test", False) auto_test = config.get("auto_test", False)
original_prompt = prompt original_prompt = prompt
fallback_handler = FallbackHandler(config)
with InterruptibleSection(): with InterruptibleSection():
try: try:
@ -872,8 +875,13 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
continue continue
logger.debug("Agent run completed successfully") logger.debug("Agent run completed successfully")
return "Agent run completed successfully" return "Agent run completed successfully"
except (ToolExecutionError, InvalidToolCall) as e: except ToolExecutionError as e:
_handle_tool_execution_error(fallback_handler, agent, e) fallback_response = _handle_tool_execution_error(
fallback_handler, agent, e
)
if fallback_response:
prompt = original_prompt + "\n" + fallback_response
continue
except (KeyboardInterrupt, AgentInterrupt): except (KeyboardInterrupt, AgentInterrupt):
raise raise
except ( except (
@ -892,6 +900,37 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
def _handle_tool_execution_error( def _handle_tool_execution_error(
fallback_handler: FallbackHandler, fallback_handler: FallbackHandler,
agent: CiaynAgent | CompiledGraph, agent: CiaynAgent | CompiledGraph,
error: ToolExecutionError | InvalidToolCall, error: ToolExecutionError,
): ):
fallback_handler.handle_failure("Tool execution error", error, agent) logger.debug("Entering _handle_tool_execution_error with error: %s", error)
if error.tool_name:
failed_tool_call_name = error.tool_name
logger.debug(
"Extracted failed_tool_call_name from error.tool_name: %s",
failed_tool_call_name,
)
else:
import re
msg = str(error)
logger.debug("Error message: %s", msg)
match = re.search(r"name=['\"](\w+)['\"]", msg)
if match:
failed_tool_call_name = match.group(1)
logger.debug(
"Extracted failed_tool_call_name using regex: %s", failed_tool_call_name
)
else:
failed_tool_call_name = "Tool execution error"
logger.debug(
"Defaulting failed_tool_call_name to: %s", failed_tool_call_name
)
logger.debug(
"Calling fallback_handler.handle_failure with failed_tool_call_name: %s",
failed_tool_call_name,
)
fallback_response = fallback_handler.handle_failure(
failed_tool_call_name, error, agent
)
logger.debug("Fallback response received: %s", fallback_response)
return fallback_response

View File

@ -2,7 +2,7 @@
DEFAULT_RECURSION_LIMIT = 100 DEFAULT_RECURSION_LIMIT = 100
DEFAULT_MAX_TEST_CMD_RETRIES = 3 DEFAULT_MAX_TEST_CMD_RETRIES = 3
DEFAULT_MAX_TOOL_FAILURES = 3 DEFAULT_MAX_TOOL_FAILURES = 2
FALLBACK_TOOL_MODEL_LIMIT = 5 FALLBACK_TOOL_MODEL_LIMIT = 5
RETRY_FALLBACK_COUNT = 3 RETRY_FALLBACK_COUNT = 3
RETRY_FALLBACK_DELAY = 2 RETRY_FALLBACK_DELAY = 2

View File

@ -1,9 +1,11 @@
from typing import Any, Dict from typing import Any, Dict, Optional
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.panel import Panel from rich.panel import Panel
from ra_aid.exceptions import ToolExecutionError
# Import shared console instance # Import shared console instance
from .formatting import console from .formatting import console
@ -33,10 +35,26 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
elif "tools" in chunk and "messages" in chunk["tools"]: elif "tools" in chunk and "messages" in chunk["tools"]:
for msg in chunk["tools"]["messages"]: for msg in chunk["tools"]["messages"]:
if msg.status == "error" and msg.content: if msg.status == "error" and msg.content:
err_msg = msg.content.strip()
console.print( console.print(
Panel( Panel(
Markdown(msg.content.strip()), Markdown(err_msg),
title="❌ Tool Error", title="❌ Tool Error",
border_style="red bold", border_style="red bold",
) )
) )
tool_name = getattr(msg, "name", None)
raise ToolExecutionError(err_msg, tool_name=tool_name)
def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:
"""
Print a message using a Panel with Markdown formatting.
Args:
message (str): The message content to display.
title (Optional[str]): An optional title for the panel.
border_style (str): Border style for the panel.
"""
console.print(Panel(Markdown(message), title=title, border_style=border_style))

View File

@ -17,5 +17,6 @@ class ToolExecutionError(Exception):
This exception is used to distinguish tool execution failures This exception is used to distinguish tool execution failures
from other types of errors in the agent system. from other types of errors in the agent system.
""" """
def __init__(self, message, tool_name=None):
pass super().__init__(message)
self.tool_name = tool_name

View File

@ -1,13 +1,24 @@
from typing import Dict
from langchain_core.tools import BaseTool
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import BaseMessage
from ra_aid.console.output import cpm
import json
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.config import ( from ra_aid.config import (
DEFAULT_MAX_TOOL_FAILURES, DEFAULT_MAX_TOOL_FAILURES,
FALLBACK_TOOL_MODEL_LIMIT, FALLBACK_TOOL_MODEL_LIMIT,
RETRY_FALLBACK_COUNT, RETRY_FALLBACK_COUNT,
) )
from ra_aid.logging_config import get_logger
from ra_aid.tool_leaderboard import supported_top_tool_models from ra_aid.tool_leaderboard import supported_top_tool_models
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from ra_aid.llm import initialize_llm, validate_provider_env
from rich.panel import Panel
from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env # from langgraph.graph.message import BaseMessage, BaseMessageChunk
# from langgraph.prebuilt import ToolNode
logger = get_logger(__name__) logger = get_logger(__name__)
@ -22,18 +33,21 @@ class FallbackHandler:
counters when a tool call succeeds. counters when a tool call succeeds.
""" """
def __init__(self, config): def __init__(self, config, tools):
""" """
Initialize the FallbackHandler with the given configuration. Initialize the FallbackHandler with the given configuration and tools.
Args: Args:
config (dict): Configuration dictionary that may include fallback settings. config (dict): Configuration dictionary that may include fallback settings.
tools (list): List of available tools.
""" """
self.config = config self.config = config
self.tools: list[BaseTool] = tools
self.fallback_enabled = config.get("fallback_tool_enabled", True) self.fallback_enabled = config.get("fallback_tool_enabled", True)
self.fallback_tool_models = self._load_fallback_tool_models(config) self.fallback_tool_models = self._load_fallback_tool_models(config)
self.tool_failure_consecutive_failures = 0 self.tool_failure_consecutive_failures = 0
self.tool_failure_used_fallbacks = set() self.tool_failure_used_fallbacks = set()
self.console = Console()
def _load_fallback_tool_models(self, config): def _load_fallback_tool_models(self, config):
""" """
@ -49,17 +63,6 @@ class FallbackHandler:
Returns: Returns:
list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model. list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model.
""" """
fallback_tool_models_config = config.get("fallback_tool_models")
if fallback_tool_models_config:
# Assume comma-separated model names; wrap each in a dict with default type "prompt"
models = []
for m in [
x.strip() for x in fallback_tool_models_config.split(",") if x.strip()
]:
models.append({"model": m, "type": "prompt"})
return models
else:
console = Console()
supported = [] supported = []
skipped = [] skipped = []
for item in supported_top_tool_models: for item in supported_top_tool_models:
@ -85,10 +88,12 @@ class FallbackHandler:
"\nSkipped top tool calling models due to missing provider ENV API keys: " "\nSkipped top tool calling models due to missing provider ENV API keys: "
+ ", ".join(skipped) + ", ".join(skipped)
) )
console.print(Panel(Markdown(message), title="Fallback Models")) cpm(message, title="Fallback Models")
return final_models return final_models
def handle_failure(self, code: str, error: Exception, agent): def handle_failure(
self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph
):
""" """
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
@ -114,7 +119,7 @@ class FallbackHandler:
logger.debug( logger.debug(
"_handle_tool_failure: threshold reached, invoking fallback mechanism." "_handle_tool_failure: threshold reached, invoking fallback mechanism."
) )
self.attempt_fallback(code, logger, agent) return self.attempt_fallback(code, logger, agent)
def attempt_fallback(self, code: str, logger, agent): def attempt_fallback(self, code: str, logger, agent):
""" """
@ -127,18 +132,14 @@ class FallbackHandler:
""" """
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}") logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
fallback_model = self.fallback_tool_models[0] fallback_model = self.fallback_tool_models[0]
failed_tool_call_name = code.split("(")[0].strip() failed_tool_call_name = code
logger.error( logger.error(
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}" f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
) )
Console().print( cpm(
Panel( f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.",
Markdown(
f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."
),
title="Fallback Notification", title="Fallback Notification",
) )
)
if fallback_model.get("type", "prompt").lower() == "fc": if fallback_model.get("type", "prompt").lower() == "fc":
self.attempt_fallback_function(code, logger, agent) self.attempt_fallback_function(code, logger, agent)
else: else:
@ -151,6 +152,30 @@ class FallbackHandler:
self.tool_failure_consecutive_failures = 0 self.tool_failure_consecutive_failures = 0
self.tool_failure_used_fallbacks.clear() self.tool_failure_used_fallbacks.clear()
def _find_tool_to_bind(self, agent, failed_tool_call_name):
logger.debug(f"failed_tool_call_name={failed_tool_call_name}")
tool_to_bind = None
if hasattr(agent, "tools"):
tool_to_bind = next(
(t for t in agent.tools if t.func.__name__ == failed_tool_call_name),
None,
)
if tool_to_bind is None:
from ra_aid.tool_configs import get_all_tools
all_tools = get_all_tools()
tool_to_bind = next(
(t for t in all_tools if t.func.__name__ == failed_tool_call_name),
None,
)
if tool_to_bind is None:
available = [t.func.__name__ for t in get_all_tools()]
logger.debug(
f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}"
)
raise Exception(f"Tool {failed_tool_call_name} not found in all tools.")
return tool_to_bind
def attempt_fallback_prompt(self, code: str, logger, agent): def attempt_fallback_prompt(self, code: str, logger, agent):
""" """
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code. Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
@ -169,42 +194,40 @@ class FallbackHandler:
Exception: If all prompt-based fallback models fail. Exception: If all prompt-based fallback models fail.
""" """
logger.debug("Attempting prompt-based fallback using fallback models") logger.debug("Attempting prompt-based fallback using fallback models")
failed_tool_call_name = code.split("(")[0].strip() failed_tool_call_name = code
for fallback_model in self.fallback_tool_models: for fallback_model in self.fallback_tool_models:
try: try:
logger.debug(f"Trying fallback model: {fallback_model['model']}") logger.debug(f"Trying fallback model: {fallback_model['model']}")
simple_model = initialize_llm( simple_model = initialize_llm(
fallback_model["provider"], fallback_model["model"] fallback_model["provider"], fallback_model["model"]
) )
tool_to_bind = next( tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
(
t
for t in agent.tools
if t.func.__name__ == failed_tool_call_name
),
None,
)
if tool_to_bind is None:
logger.debug(
f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
)
raise Exception(
f"Tool {failed_tool_call_name} not found in agent.tools"
)
binded_model = simple_model.bind_tools( binded_model = simple_model.bind_tools(
[tool_to_bind], tool_choice=failed_tool_call_name [tool_to_bind], tool_choice=failed_tool_call_name
) )
retry_model = binded_model.with_retry( # retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT # stop_after_attempt=RETRY_FALLBACK_COUNT
) # )
response = retry_model.invoke(code) response = binded_model.invoke(code)
cpm(f"response={response}")
self.tool_failure_used_fallbacks.add(fallback_model["model"]) self.tool_failure_used_fallbacks.add(fallback_model["model"])
agent.model = retry_model
self.reset_fallback_handler() tool_call = self.base_message_to_tool_call_dict(response)
if tool_call:
result = self.invoke_prompt_tool_call(tool_call)
cpm(f"result={result}")
logger.debug( logger.debug(
"Prompt-based fallback executed successfully with model: " "Prompt-based fallback executed successfully with model: "
+ fallback_model["model"] + fallback_model["model"]
) )
self.reset_fallback_handler()
return result
else:
cpm(
response.content if hasattr(response, "content") else response,
title="Fallback Model Response: " + fallback_model["model"],
)
return response return response
except Exception as e: except Exception as e:
if isinstance(e, KeyboardInterrupt): if isinstance(e, KeyboardInterrupt):
@ -232,28 +255,14 @@ class FallbackHandler:
Exception: If all function-calling fallback models fail. Exception: If all function-calling fallback models fail.
""" """
logger.debug("Attempting function-calling fallback using fallback models") logger.debug("Attempting function-calling fallback using fallback models")
failed_tool_call_name = code.split("(")[0].strip() failed_tool_call_name = code
for fallback_model in self.fallback_tool_models: for fallback_model in self.fallback_tool_models:
try: try:
logger.debug(f"Trying fallback model: {fallback_model['model']}") logger.debug(f"Trying fallback model: {fallback_model['model']}")
simple_model = initialize_llm( simple_model = initialize_llm(
fallback_model["provider"], fallback_model["model"] fallback_model["provider"], fallback_model["model"]
) )
tool_to_bind = next( tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
(
t
for t in agent.tools
if t.func.__name__ == failed_tool_call_name
),
None,
)
if tool_to_bind is None:
logger.debug(
f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
)
raise Exception(
f"Tool {failed_tool_call_name} not found in agent.tools"
)
binded_model = simple_model.bind_tools( binded_model = simple_model.bind_tools(
[tool_to_bind], tool_choice=failed_tool_call_name [tool_to_bind], tool_choice=failed_tool_call_name
) )
@ -261,13 +270,18 @@ class FallbackHandler:
stop_after_attempt=RETRY_FALLBACK_COUNT stop_after_attempt=RETRY_FALLBACK_COUNT
) )
response = retry_model.invoke(code) response = retry_model.invoke(code)
cpm(f"response={response}")
self.tool_failure_used_fallbacks.add(fallback_model["model"]) self.tool_failure_used_fallbacks.add(fallback_model["model"])
agent.model = retry_model
self.reset_fallback_handler() self.reset_fallback_handler()
logger.debug( logger.debug(
"Function-calling fallback executed successfully with model: " "Function-calling fallback executed successfully with model: "
+ fallback_model["model"] + fallback_model["model"]
) )
cpm(
response.content if hasattr(response, "content") else response,
title="Fallback Model Response: " + fallback_model["model"],
)
return response return response
except Exception as e: except Exception as e:
if isinstance(e, KeyboardInterrupt): if isinstance(e, KeyboardInterrupt):
@ -276,3 +290,58 @@ class FallbackHandler:
f"Function-calling fallback with model {fallback_model['model']} failed: {e}" f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
) )
raise Exception("All function-calling fallback models failed") raise Exception("All function-calling fallback models failed")
def invoke_prompt_tool_call(self, tool_call_request: dict):
"""
Invoke a tool call from a prompt-based fallback response.
Args:
tool_call_request (dict): The tool call request containing keys 'type', 'name', and 'arguments'.
Returns:
The result of invoking the tool.
"""
tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools}
name = tool_call_request["name"]
arguments = tool_call_request["arguments"]
# return tool_name_to_tool[name].invoke(arguments)
# tool_call_dict = {"arguments": arguments}
return tool_name_to_tool[name].invoke(arguments)
def base_message_to_tool_call_dict(self, response: BaseMessage):
"""
Extracts a tool call dictionary from a fallback response.
Args:
response: The response object containing tool call data.
Returns:
A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found,
otherwise None.
"""
tool_calls = None
if hasattr(response, "additional_kwargs") and response.additional_kwargs.get(
"tool_calls"
):
tool_calls = response.additional_kwargs.get("tool_calls")
elif hasattr(response, "tool_calls"):
tool_calls = response.tool_calls
elif isinstance(response, dict) and response.get("additional_kwargs", {}).get(
"tool_calls"
):
tool_calls = response.get("additional_kwargs").get("tool_calls")
if tool_calls:
if len(tool_calls) > 1:
logger.warning("Multiple tool calls detected, using the first one")
tool_call = tool_calls[0]
return {
"id": tool_call["id"],
"type": tool_call["type"],
"name": tool_call["function"]["name"],
"arguments": (
json.loads(tool_call["function"]["arguments"])
if isinstance(tool_call["function"]["arguments"], str)
else tool_call["function"]["arguments"]
),
}
return None

View File

@ -28,7 +28,7 @@ from ra_aid.tools.write_file import write_file_tool
# Read-only tools that don't modify system state # Read-only tools that don't modify system state
def get_read_only_tools( def get_read_only_tools(
human_interaction: bool = False, web_research_enabled: bool = False human_interaction: bool = False, web_research_enabled: bool = False
) -> list: ):
"""Get the list of read-only tools, optionally including human interaction tools. """Get the list of read-only tools, optionally including human interaction tools.
Args: Args:
@ -61,6 +61,21 @@ def get_read_only_tools(
return tools return tools
def get_all_tools_simple():
"""Return a list containing all available tools using existing group methods."""
return get_all_tools()
def get_all_tools():
"""Return a list containing all available tools from different groups."""
all_tools = []
all_tools.extend(get_read_only_tools())
all_tools.extend(MODIFICATION_TOOLS)
all_tools.extend(EXPERT_TOOLS)
all_tools.extend(RESEARCH_TOOLS)
all_tools.extend(get_web_research_tools())
all_tools.extend(get_chat_tools())
return all_tools
# Define constant tool groups # Define constant tool groups
READ_ONLY_TOOLS = get_read_only_tools() READ_ONLY_TOOLS = get_read_only_tools()
@ -81,7 +96,7 @@ def get_research_tools(
expert_enabled: bool = True, expert_enabled: bool = True,
human_interaction: bool = False, human_interaction: bool = False,
web_research_enabled: bool = False, web_research_enabled: bool = False,
) -> list: ):
"""Get the list of research tools based on mode and whether expert is enabled. """Get the list of research tools based on mode and whether expert is enabled.
Args: Args: