feat(fallback): implement fallback handler for tool execution errors to enhance error resilience and user experience
refactor(fallback): streamline fallback model selection and invocation process for improved maintainability
fix(config): reduce maximum tool failures from 3 to 2 to tighten error handling thresholds
style(console): improve error message formatting and logging for better clarity and debugging
chore(main): remove redundant fallback tool model handling from main function to simplify configuration management
This commit is contained in:
parent
1388067769
commit
67ecf72a6c
|
|
@ -427,15 +427,6 @@ def main():
|
||||||
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
||||||
|
|
||||||
_global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool
|
_global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool
|
||||||
_global_memory["config"]["fallback_tool_models"] = (
|
|
||||||
[
|
|
||||||
model.strip()
|
|
||||||
for model in args.fallback_tool_models.split(",")
|
|
||||||
if model.strip()
|
|
||||||
]
|
|
||||||
if args.fallback_tool_models
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store research config with fallback to base values
|
# Store research config with fallback to base values
|
||||||
_global_memory["config"]["research_provider"] = (
|
_global_memory["config"]["research_provider"] = (
|
||||||
|
|
@ -445,15 +436,6 @@ def main():
|
||||||
|
|
||||||
# Store fallback tool configuration
|
# Store fallback tool configuration
|
||||||
_global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool
|
_global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool
|
||||||
_global_memory["config"]["fallback_tool_models"] = (
|
|
||||||
[
|
|
||||||
model.strip()
|
|
||||||
for model in args.fallback_tool_models.split(",")
|
|
||||||
if model.strip()
|
|
||||||
]
|
|
||||||
if args.fallback_tool_models
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run research stage
|
# Run research stage
|
||||||
print_stage_header("Research Stage")
|
print_stage_header("Research Stage")
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@ from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
InvalidToolCall,
|
|
||||||
trim_messages,
|
trim_messages,
|
||||||
)
|
)
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
@ -339,9 +338,6 @@ def run_research_agent(
|
||||||
if memory is None:
|
if memory is None:
|
||||||
memory = MemorySaver()
|
memory = MemorySaver()
|
||||||
|
|
||||||
if thread_id is None:
|
|
||||||
thread_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
tools = get_research_tools(
|
tools = get_research_tools(
|
||||||
research_only=research_only,
|
research_only=research_only,
|
||||||
expert_enabled=expert_enabled,
|
expert_enabled=expert_enabled,
|
||||||
|
|
@ -413,7 +409,8 @@ def run_research_agent(
|
||||||
|
|
||||||
if agent is not None:
|
if agent is not None:
|
||||||
logger.debug("Research agent completed successfully")
|
logger.debug("Research agent completed successfully")
|
||||||
_result = run_agent_with_retry(agent, prompt, run_config)
|
fallback_handler = FallbackHandler(config, tools)
|
||||||
|
_result = run_agent_with_retry(agent, prompt, run_config, fallback_handler)
|
||||||
if _result:
|
if _result:
|
||||||
# Log research completion
|
# Log research completion
|
||||||
log_work_event(f"Completed research phase for: {base_task_or_query}")
|
log_work_event(f"Completed research phase for: {base_task_or_query}")
|
||||||
|
|
@ -529,7 +526,8 @@ def run_web_research_agent(
|
||||||
console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
|
console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
|
||||||
|
|
||||||
logger.debug("Web research agent completed successfully")
|
logger.debug("Web research agent completed successfully")
|
||||||
_result = run_agent_with_retry(agent, prompt, run_config)
|
fallback_handler = FallbackHandler(config, tools)
|
||||||
|
_result = run_agent_with_retry(agent, prompt, run_config, fallback_handler)
|
||||||
if _result:
|
if _result:
|
||||||
# Log web research completion
|
# Log web research completion
|
||||||
log_work_event(f"Completed web research phase for: {query}")
|
log_work_event(f"Completed web research phase for: {query}")
|
||||||
|
|
@ -634,7 +632,10 @@ def run_planning_agent(
|
||||||
try:
|
try:
|
||||||
print_stage_header("Planning Stage")
|
print_stage_header("Planning Stage")
|
||||||
logger.debug("Planning agent completed successfully")
|
logger.debug("Planning agent completed successfully")
|
||||||
_result = run_agent_with_retry(agent, planning_prompt, run_config)
|
fallback_handler = FallbackHandler(config, tools)
|
||||||
|
_result = run_agent_with_retry(
|
||||||
|
agent, planning_prompt, run_config, fallback_handler
|
||||||
|
)
|
||||||
if _result:
|
if _result:
|
||||||
# Log planning completion
|
# Log planning completion
|
||||||
log_work_event(f"Completed planning phase for: {base_task}")
|
log_work_event(f"Completed planning phase for: {base_task}")
|
||||||
|
|
@ -739,7 +740,8 @@ def run_task_implementation_agent(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.debug("Implementation agent completed successfully")
|
logger.debug("Implementation agent completed successfully")
|
||||||
_result = run_agent_with_retry(agent, prompt, run_config)
|
fallback_handler = FallbackHandler(config, tools)
|
||||||
|
_result = run_agent_with_retry(agent, prompt, run_config, fallback_handler)
|
||||||
if _result:
|
if _result:
|
||||||
# Log task implementation completion
|
# Log task implementation completion
|
||||||
log_work_event(f"Completed implementation of task: {task}")
|
log_work_event(f"Completed implementation of task: {task}")
|
||||||
|
|
@ -805,7 +807,7 @@ def _decrement_agent_depth():
|
||||||
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
||||||
|
|
||||||
|
|
||||||
def _run_agent_stream(agent, prompt, config):
|
def _run_agent_stream(agent: CompiledGraph, prompt: str, config: dict):
|
||||||
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
|
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
|
||||||
logger.debug("Agent output: %s", chunk)
|
logger.debug("Agent output: %s", chunk)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
|
|
@ -840,7 +842,9 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
def run_agent_with_retry(
|
||||||
|
agent, prompt: str, config: dict, fallback_handler: FallbackHandler
|
||||||
|
) -> Optional[str]:
|
||||||
"""Run an agent with retry logic for API errors."""
|
"""Run an agent with retry logic for API errors."""
|
||||||
logger.debug("Running agent with prompt length: %d", len(prompt))
|
logger.debug("Running agent with prompt length: %d", len(prompt))
|
||||||
original_handler = _setup_interrupt_handling()
|
original_handler = _setup_interrupt_handling()
|
||||||
|
|
@ -850,7 +854,6 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
||||||
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
|
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
|
||||||
auto_test = config.get("auto_test", False)
|
auto_test = config.get("auto_test", False)
|
||||||
original_prompt = prompt
|
original_prompt = prompt
|
||||||
fallback_handler = FallbackHandler(config)
|
|
||||||
|
|
||||||
with InterruptibleSection():
|
with InterruptibleSection():
|
||||||
try:
|
try:
|
||||||
|
|
@ -872,8 +875,13 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
||||||
continue
|
continue
|
||||||
logger.debug("Agent run completed successfully")
|
logger.debug("Agent run completed successfully")
|
||||||
return "Agent run completed successfully"
|
return "Agent run completed successfully"
|
||||||
except (ToolExecutionError, InvalidToolCall) as e:
|
except ToolExecutionError as e:
|
||||||
_handle_tool_execution_error(fallback_handler, agent, e)
|
fallback_response = _handle_tool_execution_error(
|
||||||
|
fallback_handler, agent, e
|
||||||
|
)
|
||||||
|
if fallback_response:
|
||||||
|
prompt = original_prompt + "\n" + fallback_response
|
||||||
|
continue
|
||||||
except (KeyboardInterrupt, AgentInterrupt):
|
except (KeyboardInterrupt, AgentInterrupt):
|
||||||
raise
|
raise
|
||||||
except (
|
except (
|
||||||
|
|
@ -892,6 +900,37 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
||||||
def _handle_tool_execution_error(
|
def _handle_tool_execution_error(
|
||||||
fallback_handler: FallbackHandler,
|
fallback_handler: FallbackHandler,
|
||||||
agent: CiaynAgent | CompiledGraph,
|
agent: CiaynAgent | CompiledGraph,
|
||||||
error: ToolExecutionError | InvalidToolCall,
|
error: ToolExecutionError,
|
||||||
):
|
):
|
||||||
fallback_handler.handle_failure("Tool execution error", error, agent)
|
logger.debug("Entering _handle_tool_execution_error with error: %s", error)
|
||||||
|
if error.tool_name:
|
||||||
|
failed_tool_call_name = error.tool_name
|
||||||
|
logger.debug(
|
||||||
|
"Extracted failed_tool_call_name from error.tool_name: %s",
|
||||||
|
failed_tool_call_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
import re
|
||||||
|
|
||||||
|
msg = str(error)
|
||||||
|
logger.debug("Error message: %s", msg)
|
||||||
|
match = re.search(r"name=['\"](\w+)['\"]", msg)
|
||||||
|
if match:
|
||||||
|
failed_tool_call_name = match.group(1)
|
||||||
|
logger.debug(
|
||||||
|
"Extracted failed_tool_call_name using regex: %s", failed_tool_call_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
failed_tool_call_name = "Tool execution error"
|
||||||
|
logger.debug(
|
||||||
|
"Defaulting failed_tool_call_name to: %s", failed_tool_call_name
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"Calling fallback_handler.handle_failure with failed_tool_call_name: %s",
|
||||||
|
failed_tool_call_name,
|
||||||
|
)
|
||||||
|
fallback_response = fallback_handler.handle_failure(
|
||||||
|
failed_tool_call_name, error, agent
|
||||||
|
)
|
||||||
|
logger.debug("Fallback response received: %s", fallback_response)
|
||||||
|
return fallback_response
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
DEFAULT_RECURSION_LIMIT = 100
|
DEFAULT_RECURSION_LIMIT = 100
|
||||||
DEFAULT_MAX_TEST_CMD_RETRIES = 3
|
DEFAULT_MAX_TEST_CMD_RETRIES = 3
|
||||||
DEFAULT_MAX_TOOL_FAILURES = 3
|
DEFAULT_MAX_TOOL_FAILURES = 2
|
||||||
FALLBACK_TOOL_MODEL_LIMIT = 5
|
FALLBACK_TOOL_MODEL_LIMIT = 5
|
||||||
RETRY_FALLBACK_COUNT = 3
|
RETRY_FALLBACK_COUNT = 3
|
||||||
RETRY_FALLBACK_DELAY = 2
|
RETRY_FALLBACK_DELAY = 2
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
|
|
||||||
# Import shared console instance
|
# Import shared console instance
|
||||||
from .formatting import console
|
from .formatting import console
|
||||||
|
|
||||||
|
|
@ -33,10 +35,26 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
|
||||||
elif "tools" in chunk and "messages" in chunk["tools"]:
|
elif "tools" in chunk and "messages" in chunk["tools"]:
|
||||||
for msg in chunk["tools"]["messages"]:
|
for msg in chunk["tools"]["messages"]:
|
||||||
if msg.status == "error" and msg.content:
|
if msg.status == "error" and msg.content:
|
||||||
|
err_msg = msg.content.strip()
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
Markdown(msg.content.strip()),
|
Markdown(err_msg),
|
||||||
title="❌ Tool Error",
|
title="❌ Tool Error",
|
||||||
border_style="red bold",
|
border_style="red bold",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
tool_name = getattr(msg, "name", None)
|
||||||
|
raise ToolExecutionError(err_msg, tool_name=tool_name)
|
||||||
|
|
||||||
|
|
||||||
|
def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:
|
||||||
|
"""
|
||||||
|
Print a message using a Panel with Markdown formatting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message (str): The message content to display.
|
||||||
|
title (Optional[str]): An optional title for the panel.
|
||||||
|
border_style (str): Border style for the panel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
console.print(Panel(Markdown(message), title=title, border_style=border_style))
|
||||||
|
|
|
||||||
|
|
@ -17,5 +17,6 @@ class ToolExecutionError(Exception):
|
||||||
This exception is used to distinguish tool execution failures
|
This exception is used to distinguish tool execution failures
|
||||||
from other types of errors in the agent system.
|
from other types of errors in the agent system.
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, message, tool_name=None):
|
||||||
pass
|
super().__init__(message)
|
||||||
|
self.tool_name = tool_name
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,24 @@
|
||||||
|
from typing import Dict
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langgraph.graph.graph import CompiledGraph
|
||||||
|
from langgraph.graph.message import BaseMessage
|
||||||
|
|
||||||
|
from ra_aid.console.output import cpm
|
||||||
|
import json
|
||||||
|
|
||||||
|
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
from ra_aid.config import (
|
from ra_aid.config import (
|
||||||
DEFAULT_MAX_TOOL_FAILURES,
|
DEFAULT_MAX_TOOL_FAILURES,
|
||||||
FALLBACK_TOOL_MODEL_LIMIT,
|
FALLBACK_TOOL_MODEL_LIMIT,
|
||||||
RETRY_FALLBACK_COUNT,
|
RETRY_FALLBACK_COUNT,
|
||||||
)
|
)
|
||||||
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.tool_leaderboard import supported_top_tool_models
|
from ra_aid.tool_leaderboard import supported_top_tool_models
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from ra_aid.llm import initialize_llm, validate_provider_env
|
||||||
from rich.panel import Panel
|
|
||||||
from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
|
# from langgraph.graph.message import BaseMessage, BaseMessageChunk
|
||||||
|
# from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -22,18 +33,21 @@ class FallbackHandler:
|
||||||
counters when a tool call succeeds.
|
counters when a tool call succeeds.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config, tools):
|
||||||
"""
|
"""
|
||||||
Initialize the FallbackHandler with the given configuration.
|
Initialize the FallbackHandler with the given configuration and tools.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (dict): Configuration dictionary that may include fallback settings.
|
config (dict): Configuration dictionary that may include fallback settings.
|
||||||
|
tools (list): List of available tools.
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.tools: list[BaseTool] = tools
|
||||||
self.fallback_enabled = config.get("fallback_tool_enabled", True)
|
self.fallback_enabled = config.get("fallback_tool_enabled", True)
|
||||||
self.fallback_tool_models = self._load_fallback_tool_models(config)
|
self.fallback_tool_models = self._load_fallback_tool_models(config)
|
||||||
self.tool_failure_consecutive_failures = 0
|
self.tool_failure_consecutive_failures = 0
|
||||||
self.tool_failure_used_fallbacks = set()
|
self.tool_failure_used_fallbacks = set()
|
||||||
|
self.console = Console()
|
||||||
|
|
||||||
def _load_fallback_tool_models(self, config):
|
def _load_fallback_tool_models(self, config):
|
||||||
"""
|
"""
|
||||||
|
|
@ -49,46 +63,37 @@ class FallbackHandler:
|
||||||
Returns:
|
Returns:
|
||||||
list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model.
|
list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model.
|
||||||
"""
|
"""
|
||||||
fallback_tool_models_config = config.get("fallback_tool_models")
|
supported = []
|
||||||
if fallback_tool_models_config:
|
skipped = []
|
||||||
# Assume comma-separated model names; wrap each in a dict with default type "prompt"
|
for item in supported_top_tool_models:
|
||||||
models = []
|
provider = item.get("provider")
|
||||||
for m in [
|
model_name = item.get("model")
|
||||||
x.strip() for x in fallback_tool_models_config.split(",") if x.strip()
|
if validate_provider_env(provider):
|
||||||
]:
|
supported.append(item)
|
||||||
models.append({"model": m, "type": "prompt"})
|
if len(supported) == FALLBACK_TOOL_MODEL_LIMIT:
|
||||||
return models
|
break
|
||||||
else:
|
else:
|
||||||
console = Console()
|
skipped.append(model_name)
|
||||||
supported = []
|
final_models = []
|
||||||
skipped = []
|
for item in supported:
|
||||||
for item in supported_top_tool_models:
|
if "type" not in item:
|
||||||
provider = item.get("provider")
|
item["type"] = "prompt"
|
||||||
model_name = item.get("model")
|
item["model"] = item["model"].lower()
|
||||||
if validate_provider_env(provider):
|
final_models.append(item)
|
||||||
supported.append(item)
|
message = "Fallback models selected: " + ", ".join(
|
||||||
if len(supported) == FALLBACK_TOOL_MODEL_LIMIT:
|
[m["model"] for m in final_models]
|
||||||
break
|
)
|
||||||
else:
|
if skipped:
|
||||||
skipped.append(model_name)
|
message += (
|
||||||
final_models = []
|
"\nSkipped top tool calling models due to missing provider ENV API keys: "
|
||||||
for item in supported:
|
+ ", ".join(skipped)
|
||||||
if "type" not in item:
|
|
||||||
item["type"] = "prompt"
|
|
||||||
item["model"] = item["model"].lower()
|
|
||||||
final_models.append(item)
|
|
||||||
message = "Fallback models selected: " + ", ".join(
|
|
||||||
[m["model"] for m in final_models]
|
|
||||||
)
|
)
|
||||||
if skipped:
|
cpm(message, title="Fallback Models")
|
||||||
message += (
|
return final_models
|
||||||
"\nSkipped top tool calling models due to missing provider ENV API keys: "
|
|
||||||
+ ", ".join(skipped)
|
|
||||||
)
|
|
||||||
console.print(Panel(Markdown(message), title="Fallback Models"))
|
|
||||||
return final_models
|
|
||||||
|
|
||||||
def handle_failure(self, code: str, error: Exception, agent):
|
def handle_failure(
|
||||||
|
self, code: str, error: Exception, agent: CiaynAgent | CompiledGraph
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
||||||
|
|
||||||
|
|
@ -114,7 +119,7 @@ class FallbackHandler:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"_handle_tool_failure: threshold reached, invoking fallback mechanism."
|
"_handle_tool_failure: threshold reached, invoking fallback mechanism."
|
||||||
)
|
)
|
||||||
self.attempt_fallback(code, logger, agent)
|
return self.attempt_fallback(code, logger, agent)
|
||||||
|
|
||||||
def attempt_fallback(self, code: str, logger, agent):
|
def attempt_fallback(self, code: str, logger, agent):
|
||||||
"""
|
"""
|
||||||
|
|
@ -127,17 +132,13 @@ class FallbackHandler:
|
||||||
"""
|
"""
|
||||||
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
|
logger.debug(f"_attempt_fallback: initiating fallback for code: {code}")
|
||||||
fallback_model = self.fallback_tool_models[0]
|
fallback_model = self.fallback_tool_models[0]
|
||||||
failed_tool_call_name = code.split("(")[0].strip()
|
failed_tool_call_name = code
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
|
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback to model: {fallback_model['model']} for tool: {failed_tool_call_name}"
|
||||||
)
|
)
|
||||||
Console().print(
|
cpm(
|
||||||
Panel(
|
f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}.",
|
||||||
Markdown(
|
title="Fallback Notification",
|
||||||
f"**Tool fallback activated**: Switching to fallback model {fallback_model['model']} for tool {failed_tool_call_name}."
|
|
||||||
),
|
|
||||||
title="Fallback Notification",
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if fallback_model.get("type", "prompt").lower() == "fc":
|
if fallback_model.get("type", "prompt").lower() == "fc":
|
||||||
self.attempt_fallback_function(code, logger, agent)
|
self.attempt_fallback_function(code, logger, agent)
|
||||||
|
|
@ -151,6 +152,30 @@ class FallbackHandler:
|
||||||
self.tool_failure_consecutive_failures = 0
|
self.tool_failure_consecutive_failures = 0
|
||||||
self.tool_failure_used_fallbacks.clear()
|
self.tool_failure_used_fallbacks.clear()
|
||||||
|
|
||||||
|
def _find_tool_to_bind(self, agent, failed_tool_call_name):
|
||||||
|
logger.debug(f"failed_tool_call_name={failed_tool_call_name}")
|
||||||
|
tool_to_bind = None
|
||||||
|
if hasattr(agent, "tools"):
|
||||||
|
tool_to_bind = next(
|
||||||
|
(t for t in agent.tools if t.func.__name__ == failed_tool_call_name),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if tool_to_bind is None:
|
||||||
|
from ra_aid.tool_configs import get_all_tools
|
||||||
|
|
||||||
|
all_tools = get_all_tools()
|
||||||
|
tool_to_bind = next(
|
||||||
|
(t for t in all_tools if t.func.__name__ == failed_tool_call_name),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if tool_to_bind is None:
|
||||||
|
available = [t.func.__name__ for t in get_all_tools()]
|
||||||
|
logger.debug(
|
||||||
|
f"Failed to find tool: {failed_tool_call_name}. Available tools: {available}"
|
||||||
|
)
|
||||||
|
raise Exception(f"Tool {failed_tool_call_name} not found in all tools.")
|
||||||
|
return tool_to_bind
|
||||||
|
|
||||||
def attempt_fallback_prompt(self, code: str, logger, agent):
|
def attempt_fallback_prompt(self, code: str, logger, agent):
|
||||||
"""
|
"""
|
||||||
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
|
Attempt a prompt-based fallback by iterating over fallback models and invoking the provided code.
|
||||||
|
|
@ -169,43 +194,41 @@ class FallbackHandler:
|
||||||
Exception: If all prompt-based fallback models fail.
|
Exception: If all prompt-based fallback models fail.
|
||||||
"""
|
"""
|
||||||
logger.debug("Attempting prompt-based fallback using fallback models")
|
logger.debug("Attempting prompt-based fallback using fallback models")
|
||||||
failed_tool_call_name = code.split("(")[0].strip()
|
failed_tool_call_name = code
|
||||||
for fallback_model in self.fallback_tool_models:
|
for fallback_model in self.fallback_tool_models:
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
||||||
simple_model = initialize_llm(
|
simple_model = initialize_llm(
|
||||||
fallback_model["provider"], fallback_model["model"]
|
fallback_model["provider"], fallback_model["model"]
|
||||||
)
|
)
|
||||||
tool_to_bind = next(
|
tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
|
||||||
(
|
|
||||||
t
|
|
||||||
for t in agent.tools
|
|
||||||
if t.func.__name__ == failed_tool_call_name
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if tool_to_bind is None:
|
|
||||||
logger.debug(
|
|
||||||
f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
|
|
||||||
)
|
|
||||||
raise Exception(
|
|
||||||
f"Tool {failed_tool_call_name} not found in agent.tools"
|
|
||||||
)
|
|
||||||
binded_model = simple_model.bind_tools(
|
binded_model = simple_model.bind_tools(
|
||||||
[tool_to_bind], tool_choice=failed_tool_call_name
|
[tool_to_bind], tool_choice=failed_tool_call_name
|
||||||
)
|
)
|
||||||
retry_model = binded_model.with_retry(
|
# retry_model = binded_model.with_retry(
|
||||||
stop_after_attempt=RETRY_FALLBACK_COUNT
|
# stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||||
)
|
# )
|
||||||
response = retry_model.invoke(code)
|
response = binded_model.invoke(code)
|
||||||
|
cpm(f"response={response}")
|
||||||
|
|
||||||
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
||||||
agent.model = retry_model
|
|
||||||
self.reset_fallback_handler()
|
tool_call = self.base_message_to_tool_call_dict(response)
|
||||||
logger.debug(
|
if tool_call:
|
||||||
"Prompt-based fallback executed successfully with model: "
|
result = self.invoke_prompt_tool_call(tool_call)
|
||||||
+ fallback_model["model"]
|
cpm(f"result={result}")
|
||||||
)
|
logger.debug(
|
||||||
return response
|
"Prompt-based fallback executed successfully with model: "
|
||||||
|
+ fallback_model["model"]
|
||||||
|
)
|
||||||
|
self.reset_fallback_handler()
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
cpm(
|
||||||
|
response.content if hasattr(response, "content") else response,
|
||||||
|
title="Fallback Model Response: " + fallback_model["model"],
|
||||||
|
)
|
||||||
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, KeyboardInterrupt):
|
if isinstance(e, KeyboardInterrupt):
|
||||||
raise
|
raise
|
||||||
|
|
@ -232,28 +255,14 @@ class FallbackHandler:
|
||||||
Exception: If all function-calling fallback models fail.
|
Exception: If all function-calling fallback models fail.
|
||||||
"""
|
"""
|
||||||
logger.debug("Attempting function-calling fallback using fallback models")
|
logger.debug("Attempting function-calling fallback using fallback models")
|
||||||
failed_tool_call_name = code.split("(")[0].strip()
|
failed_tool_call_name = code
|
||||||
for fallback_model in self.fallback_tool_models:
|
for fallback_model in self.fallback_tool_models:
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
logger.debug(f"Trying fallback model: {fallback_model['model']}")
|
||||||
simple_model = initialize_llm(
|
simple_model = initialize_llm(
|
||||||
fallback_model["provider"], fallback_model["model"]
|
fallback_model["provider"], fallback_model["model"]
|
||||||
)
|
)
|
||||||
tool_to_bind = next(
|
tool_to_bind = self._find_tool_to_bind(agent, failed_tool_call_name)
|
||||||
(
|
|
||||||
t
|
|
||||||
for t in agent.tools
|
|
||||||
if t.func.__name__ == failed_tool_call_name
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if tool_to_bind is None:
|
|
||||||
logger.debug(
|
|
||||||
f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
|
|
||||||
)
|
|
||||||
raise Exception(
|
|
||||||
f"Tool {failed_tool_call_name} not found in agent.tools"
|
|
||||||
)
|
|
||||||
binded_model = simple_model.bind_tools(
|
binded_model = simple_model.bind_tools(
|
||||||
[tool_to_bind], tool_choice=failed_tool_call_name
|
[tool_to_bind], tool_choice=failed_tool_call_name
|
||||||
)
|
)
|
||||||
|
|
@ -261,13 +270,18 @@ class FallbackHandler:
|
||||||
stop_after_attempt=RETRY_FALLBACK_COUNT
|
stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||||
)
|
)
|
||||||
response = retry_model.invoke(code)
|
response = retry_model.invoke(code)
|
||||||
|
cpm(f"response={response}")
|
||||||
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
self.tool_failure_used_fallbacks.add(fallback_model["model"])
|
||||||
agent.model = retry_model
|
|
||||||
self.reset_fallback_handler()
|
self.reset_fallback_handler()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Function-calling fallback executed successfully with model: "
|
"Function-calling fallback executed successfully with model: "
|
||||||
+ fallback_model["model"]
|
+ fallback_model["model"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cpm(
|
||||||
|
response.content if hasattr(response, "content") else response,
|
||||||
|
title="Fallback Model Response: " + fallback_model["model"],
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, KeyboardInterrupt):
|
if isinstance(e, KeyboardInterrupt):
|
||||||
|
|
@ -276,3 +290,58 @@ class FallbackHandler:
|
||||||
f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
|
f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
|
||||||
)
|
)
|
||||||
raise Exception("All function-calling fallback models failed")
|
raise Exception("All function-calling fallback models failed")
|
||||||
|
|
||||||
|
def invoke_prompt_tool_call(self, tool_call_request: dict):
|
||||||
|
"""
|
||||||
|
Invoke a tool call from a prompt-based fallback response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_call_request (dict): The tool call request containing keys 'type', 'name', and 'arguments'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of invoking the tool.
|
||||||
|
"""
|
||||||
|
tool_name_to_tool = {tool.func.__name__: tool for tool in self.tools}
|
||||||
|
name = tool_call_request["name"]
|
||||||
|
arguments = tool_call_request["arguments"]
|
||||||
|
# return tool_name_to_tool[name].invoke(arguments)
|
||||||
|
# tool_call_dict = {"arguments": arguments}
|
||||||
|
return tool_name_to_tool[name].invoke(arguments)
|
||||||
|
|
||||||
|
def base_message_to_tool_call_dict(self, response: BaseMessage):
|
||||||
|
"""
|
||||||
|
Extracts a tool call dictionary from a fallback response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: The response object containing tool call data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found,
|
||||||
|
otherwise None.
|
||||||
|
"""
|
||||||
|
tool_calls = None
|
||||||
|
if hasattr(response, "additional_kwargs") and response.additional_kwargs.get(
|
||||||
|
"tool_calls"
|
||||||
|
):
|
||||||
|
tool_calls = response.additional_kwargs.get("tool_calls")
|
||||||
|
elif hasattr(response, "tool_calls"):
|
||||||
|
tool_calls = response.tool_calls
|
||||||
|
elif isinstance(response, dict) and response.get("additional_kwargs", {}).get(
|
||||||
|
"tool_calls"
|
||||||
|
):
|
||||||
|
tool_calls = response.get("additional_kwargs").get("tool_calls")
|
||||||
|
if tool_calls:
|
||||||
|
if len(tool_calls) > 1:
|
||||||
|
logger.warning("Multiple tool calls detected, using the first one")
|
||||||
|
tool_call = tool_calls[0]
|
||||||
|
return {
|
||||||
|
"id": tool_call["id"],
|
||||||
|
"type": tool_call["type"],
|
||||||
|
"name": tool_call["function"]["name"],
|
||||||
|
"arguments": (
|
||||||
|
json.loads(tool_call["function"]["arguments"])
|
||||||
|
if isinstance(tool_call["function"]["arguments"], str)
|
||||||
|
else tool_call["function"]["arguments"]
|
||||||
|
),
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ from ra_aid.tools.write_file import write_file_tool
|
||||||
# Read-only tools that don't modify system state
|
# Read-only tools that don't modify system state
|
||||||
def get_read_only_tools(
|
def get_read_only_tools(
|
||||||
human_interaction: bool = False, web_research_enabled: bool = False
|
human_interaction: bool = False, web_research_enabled: bool = False
|
||||||
) -> list:
|
):
|
||||||
"""Get the list of read-only tools, optionally including human interaction tools.
|
"""Get the list of read-only tools, optionally including human interaction tools.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -61,6 +61,21 @@ def get_read_only_tools(
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
def get_all_tools_simple():
|
||||||
|
"""Return a list containing all available tools using existing group methods."""
|
||||||
|
return get_all_tools()
|
||||||
|
|
||||||
|
def get_all_tools():
|
||||||
|
"""Return a list containing all available tools from different groups."""
|
||||||
|
all_tools = []
|
||||||
|
all_tools.extend(get_read_only_tools())
|
||||||
|
all_tools.extend(MODIFICATION_TOOLS)
|
||||||
|
all_tools.extend(EXPERT_TOOLS)
|
||||||
|
all_tools.extend(RESEARCH_TOOLS)
|
||||||
|
all_tools.extend(get_web_research_tools())
|
||||||
|
all_tools.extend(get_chat_tools())
|
||||||
|
return all_tools
|
||||||
|
|
||||||
|
|
||||||
# Define constant tool groups
|
# Define constant tool groups
|
||||||
READ_ONLY_TOOLS = get_read_only_tools()
|
READ_ONLY_TOOLS = get_read_only_tools()
|
||||||
|
|
@ -81,7 +96,7 @@ def get_research_tools(
|
||||||
expert_enabled: bool = True,
|
expert_enabled: bool = True,
|
||||||
human_interaction: bool = False,
|
human_interaction: bool = False,
|
||||||
web_research_enabled: bool = False,
|
web_research_enabled: bool = False,
|
||||||
) -> list:
|
):
|
||||||
"""Get the list of research tools based on mode and whether expert is enabled.
|
"""Get the list of research tools based on mode and whether expert is enabled.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue