feat(agent_utils.py): refactor agent stream handling to improve clarity and maintainability by introducing reset_agent_completion_flags function

feat(fallback_handler.py): enhance fallback handling by allowing RAgents type and improving error handling
fix(config.py): update RAgents type definition to include both CompiledGraph and CiaynAgent for better type safety
refactor(fallback_handler.py): streamline fallback model invocation and response handling for improved readability and functionality
commit 96b41458a1
parent 6e8b0f2e42
Ariel Frischer, 2025-02-12 15:35:31 -08:00

3 changed files with 99 additions and 118 deletions

agent_utils.py

@@ -28,7 +28,7 @@ from rich.markdown import Markdown
 from rich.panel import Panel
 from ra_aid.agents.ciayn_agent import CiaynAgent
-from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
+from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents
 from ra_aid.console.formatting import print_error, print_stage_header
 from ra_aid.console.output import print_agent_output
 from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
@@ -807,16 +807,10 @@ def _decrement_agent_depth():
     _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
-def _run_agent_stream(agent: CompiledGraph, prompt: str, config: dict):
-    for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
-        logger.debug("Agent output: %s", chunk)
-        check_interrupt()
-        print_agent_output(chunk)
-        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
-            _global_memory["plan_completed"] = False
-            _global_memory["task_completed"] = False
-            _global_memory["completion_message"] = ""
-            break
+def reset_agent_completion_flags():
+    _global_memory["plan_completed"] = False
+    _global_memory["task_completed"] = False
+    _global_memory["completion_message"] = ""
 def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
@@ -842,8 +836,26 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
     time.sleep(0.1)
+def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
+    for chunk in agent.stream({"messages": msg_list}, config):
+        logger.debug("Agent output: %s", chunk)
+        check_interrupt()
+        print_agent_output(chunk)
+        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
+            reset_agent_completion_flags()
+            break
 def run_agent_with_retry(
-    agent, prompt: str, config: dict, fallback_handler: FallbackHandler
+    agent: RAgents,
+    prompt: str,
+    config: dict,
+    fallback_handler: FallbackHandler,
 ) -> Optional[str]:
     """Run an agent with retry logic for API errors."""
     logger.debug("Running agent with prompt length: %d", len(prompt))
@@ -854,6 +866,7 @@ run_agent_with_retry(
     _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
     auto_test = config.get("auto_test", False)
     original_prompt = prompt
+    msg_list = [HumanMessage(content=prompt)]
     with InterruptibleSection():
         try:
@@ -862,7 +875,7 @@ run_agent_with_retry(
                 logger.debug("Attempt %d/%d", attempt + 1, max_retries)
                 check_interrupt()
                 try:
-                    _run_agent_stream(agent, prompt, config)
+                    _run_agent_stream(agent, msg_list, config)
                     fallback_handler.reset_fallback_handler()
                     should_break, prompt, auto_test, test_attempts = (
                         _execute_test_command_wrapper(
@@ -878,7 +891,7 @@ run_agent_with_retry(
                 except ToolExecutionError as e:
                     fallback_response = fallback_handler.handle_failure(e, agent)
                     if fallback_response:
-                        prompt = original_prompt + "\n" + str(fallback_response)
+                        msg_list.extend(fallback_response)
                     continue
                 except (KeyboardInterrupt, AgentInterrupt):
                     raise
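
Taken together, the agent_utils.py changes move retry state from a mutable prompt string into a single message list that survives across attempts. A minimal sketch of the resulting control flow, reusing the `_run_agent_stream`, `ToolExecutionError`, and fallback-handler names from the diff above (simplified, not the actual implementation):

```python
from langchain_core.messages import BaseMessage, HumanMessage


def run_with_retry_sketch(agent, prompt: str, config: dict, fallback_handler) -> None:
    # One message list lives across every retry attempt, so context recovered
    # from failed tool calls accumulates instead of being re-serialized into
    # a fresh prompt string each time. (max_retries default is illustrative.)
    msg_list: list[BaseMessage] = [HumanMessage(content=prompt)]
    for attempt in range(config.get("max_retries", 20)):
        try:
            # Streams chunks until the agent finishes or a completion flag flips.
            _run_agent_stream(agent, msg_list, config)
            return
        except ToolExecutionError as e:
            fallback_response = fallback_handler.handle_failure(e, agent)
            if fallback_response:
                # handle_failure returns SystemMessages describing the raw LLM
                # response and the executed tool result; extending the list
                # hands that recovered context to the next attempt.
                msg_list.extend(fallback_response)
```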

config.py

@@ -15,3 +15,8 @@ VALID_PROVIDERS = [
     "deepseek",
     "gemini",
 ]
+from ra_aid.agents.ciayn_agent import CiaynAgent
+from langgraph.graph.graph import CompiledGraph
+RAgents = CompiledGraph | CiaynAgent
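
The new `RAgents` alias uses PEP 604 union syntax, which requires Python 3.10 or newer at runtime; placing these imports at the bottom of config.py is presumably a way to sidestep a circular import between the config and agent modules. A hypothetical sketch of how a call site can still narrow the union:

```python
from langgraph.graph.graph import CompiledGraph

from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.config import RAgents


def describe_agent(agent: RAgents) -> str:
    # RAgents = CompiledGraph | CiaynAgent, so isinstance still narrows the
    # union when behavior must differ per agent type.
    if isinstance(agent, CiaynAgent):
        return "CiaynAgent"
    return "CompiledGraph"
```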

fallback_handler.py

@@ -1,16 +1,16 @@
 import json
 import re
+from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.tools import BaseTool
-from langgraph.graph.graph import CompiledGraph
+from langgraph.graph.message import BaseMessage
+from langchain_core.messages import SystemMessage, HumanMessage
 from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.config import (
     DEFAULT_MAX_TOOL_FAILURES,
     FALLBACK_TOOL_MODEL_LIMIT,
     RETRY_FALLBACK_COUNT,
+    RAgents,
 )
 from ra_aid.console.output import cpm
 from ra_aid.exceptions import ToolExecutionError
@@ -47,11 +47,18 @@ class FallbackHandler:
         self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
         self.tool_failure_consecutive_failures = 0
+        self.failed_messages: list[BaseMessage] = []
         self.tool_failure_used_fallbacks = set()
         self.current_failing_tool_name = ""
-        self.current_tool_to_bind = None
+        self.current_tool_to_bind: None | BaseTool = None
+        cpm(
+            "Fallback models selected: " + ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
+            title="Fallback Models",
+        )
-    def _load_fallback_tool_models(self, config):
+    def _format_model(self, m: dict) -> str:
+        return f"{m.get('model', '')} ({m.get('type', 'prompt')})"
+    def _load_fallback_tool_models(self, _config):
         """
         Load and return fallback tool models based on the provided configuration.
@@ -90,12 +97,9 @@ class FallbackHandler:
                 "\nSkipped top tool calling models due to missing provider ENV API keys: "
                 + ", ".join(skipped)
             )
-            cpm(message, title="Fallback Models")
         return final_models
-    def handle_failure(
-        self, error: ToolExecutionError, agent: CiaynAgent | CompiledGraph
-    ):
+    def handle_failure(self, error: ToolExecutionError, agent: RAgents):
         """
         Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
@@ -137,10 +141,10 @@ class FallbackHandler:
     def attempt_fallback(self):
         """
-        Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method.
+        Initiate the fallback process by iterating over all fallback models to attempt to fix the failing tool call.
         Returns:
-            The response from a fallback model if any, otherwise None.
+            List of [raw_llm_response (SystemMessage), tool_call_result (SystemMessage)] or None.
         """
         logger.error(
             f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
@@ -150,12 +154,10 @@ class FallbackHandler:
             title="Fallback Notification",
         )
         for fallback_model in self.fallback_tool_models:
-            if fallback_model.get("type", "prompt").lower() == "fc":
-                response = self.attempt_fallback_function(fallback_model)
-            else:
-                response = self.attempt_fallback_prompt(fallback_model)
-            if response:
-                return response
+            result_list = self.invoke_fallback(fallback_model)
+            if result_list:
+                msg_list_response = [SystemMessage(str(msg)) for msg in result_list]
+                return msg_list_response
         cpm("All fallback models have failed", title="Fallback Failed")
         return None
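
The loop above fixes a single return contract for fallbacks: `invoke_fallback` yields a two-element list (raw model response, executed tool result), and `attempt_fallback` stringifies each element into a `SystemMessage` so the retry loop in agent_utils.py can simply extend its message list. A toy illustration of that hand-off, with fabricated values standing in for the real objects:

```python
from langchain_core.messages import SystemMessage

# Shapes are illustrative, not the actual objects returned by the handler.
raw_llm_response = {"tool_calls": [{"name": "run_shell", "args": {"command": "ls"}}]}
tool_call_result = "file_a.py  file_b.py"
result_list = [raw_llm_response, tool_call_result]

# Mirrors attempt_fallback: wrap each element so msg_list.extend(...) works.
msg_list_response = [SystemMessage(str(msg)) for msg in result_list]
```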
@@ -165,8 +167,9 @@ class FallbackHandler:
         """
         self.tool_failure_consecutive_failures = 0
+        self.failed_messages.clear()
         self.tool_failure_used_fallbacks.clear()
         self.fallback_tool_models = self._load_fallback_tool_models(self.config)
         self.current_failing_tool_name = ""
         self.current_tool_to_bind = None
     def _reset_on_new_failure(self, failed_tool_call_name):
         if (
@@ -218,9 +221,21 @@ class FallbackHandler:
         )
         return tool_to_bind
-    def attempt_fallback_prompt(self, fallback_model):
+    def _bind_tool_model(self, simple_model: BaseChatModel, fallback_model):
+        if fallback_model.get("type", "prompt").lower() == "fc":
+            # Force tool calling with the tool_choice param.
+            bound_model = simple_model.bind_tools(
+                [self.current_tool_to_bind],
+                tool_choice=self.current_failing_tool_name,
+            )
+        else:
+            # Do not force tool calling (prompt method).
+            bound_model = simple_model.bind_tools([self.current_tool_to_bind])
+        return bound_model
+    def invoke_fallback(self, fallback_model):
         """
-        Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model.
+        Attempt a prompt-based or function-calling fallback by invoking the current failing tool with the given fallback model.
         Args:
             fallback_model (dict): The fallback model to use.
@@ -229,88 +244,33 @@ class FallbackHandler:
             The response from the fallback model invocation, or None if failed.
         """
         try:
-            logger.debug(f"Trying fallback model: {fallback_model['model']}")
+            logger.debug(f"Trying fallback model: {self._format_model(fallback_model)}")
             simple_model = initialize_llm(
                 fallback_model["provider"], fallback_model["model"]
             )
-            binded_model = simple_model.bind_tools(
-                [self.current_tool_to_bind],
-                tool_choice=self.current_failing_tool_name,
-            )
-            retry_model = binded_model.with_retry(
+            bound_model = self._bind_tool_model(simple_model, fallback_model)
+            retry_model = bound_model.with_retry(
                 stop_after_attempt=RETRY_FALLBACK_COUNT
             )
-            # msg_list = []
             msg_list = self.construct_prompt_msg_list()
-            # response = retry_model.invoke(self.current_failing_tool_name)
             response = retry_model.invoke(msg_list)
-            cpm(f"response={response}")
             self.tool_failure_used_fallbacks.add(fallback_model["model"])
+            logger.debug(f"raw llm response={response}")
             tool_call = self.base_message_to_tool_call_dict(response)
             if tool_call:
-                result = self.invoke_prompt_tool_call(tool_call)
-                cpm(f"result={result}")
-                logger.debug(
-                    "Prompt-based fallback executed successfully with model: "
-                    + fallback_model["model"]
-                )
+                tool_call_result = self.invoke_prompt_tool_call(tool_call)
+                cpm(str(tool_call_result), title="Fallback Tool Call Result")
+                logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}")
                 self.reset_fallback_handler()
-                return result
-            else:
-                cpm(
-                    response.content if hasattr(response, "content") else response,
-                    title="Fallback Model Response: " + fallback_model["model"],
-                )
-                return response
+                return [response, tool_call_result]
         except Exception as e:
             if isinstance(e, KeyboardInterrupt):
                 raise
-            logger.error(
-                f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
-            )
-            return None
-    def attempt_fallback_function(self, fallback_model):
-        """
-        Attempt a function-calling fallback by invoking the current failing tool with the given fallback model.
-        Args:
-            fallback_model (dict): The fallback model to use.
-        Returns:
-            The response from the fallback model invocation, or None if failed.
-        """
-        try:
-            logger.debug(f"Trying fallback model: {fallback_model['model']}")
-            simple_model = initialize_llm(
-                fallback_model["provider"], fallback_model["model"]
-            )
-            binded_model = simple_model.bind_tools(
-                [self.current_tool_to_bind],
-                tool_choice=self.current_failing_tool_name,
-            )
-            retry_model = binded_model.with_retry(
-                stop_after_attempt=RETRY_FALLBACK_COUNT
-            )
-            response = retry_model.invoke(self.current_failing_tool_name)
-            cpm(f"response={response}")
-            self.tool_failure_used_fallbacks.add(fallback_model["model"])
-            logger.debug(
-                "Function-calling fallback executed successfully with model: "
-                + fallback_model["model"]
-            )
-            cpm(
-                response.content if hasattr(response, "content") else response,
-                title="Fallback Model Response: " + fallback_model["model"],
-            )
-            return response
-        except Exception as e:
-            if isinstance(e, KeyboardInterrupt):
-                raise
-            logger.error(
-                f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
-            )
+            logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}")
             return None
     def construct_prompt_msg_list(self):
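
The `tool_choice` branch in `_bind_tool_model` is the whole difference between the two fallback types: `"fc"` models are forced to emit a call to the failing tool, while `"prompt"` models merely receive the tool schema and may answer in prose. A standalone sketch of the same pattern against LangChain's public API (the provider, model name, and tool are illustrative, not taken from ra-aid):

```python
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI  # any bind_tools-capable chat model works


@tool
def run_shell(command: str) -> str:
    """Run a shell command and return its output."""
    return f"ran: {command}"


model = ChatOpenAI(model="gpt-4o-mini")

# "fc" style: tool_choice forces the model to call run_shell.
forced = model.bind_tools([run_shell], tool_choice="run_shell")

# "prompt" style: the tool is available, but the call is not forced.
optional = model.bind_tools([run_shell])

# Either bound runnable can then be wrapped with retries, as in the diff.
retrying = forced.with_retry(stop_after_attempt=3)
```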
@@ -367,19 +327,22 @@ class FallbackHandler:
         otherwise None.
         """
         tool_calls = self.get_tool_calls(response)
-        if tool_calls:
+        if not tool_calls:
+            raise Exception(
+                f"Could not extract tool_call_dict from response: {response}"
+            )
         if len(tool_calls) > 1:
             logger.warning("Multiple tool calls detected, using the first one")
         tool_call = tool_calls[0]
         return {
             "id": tool_call["id"],
             "type": tool_call["type"],
             "name": tool_call["function"]["name"],
-            "arguments": self._parse_tool_arguments(
-                tool_call["function"]["arguments"]
-            ),
+            "arguments": self._parse_tool_arguments(tool_call["function"]["arguments"]),
         }
-        return None
     def _parse_tool_arguments(self, tool_arguments):
         """