feat(agent_utils.py): refactor agent stream handling to improve clarity and maintainability by introducing a reset_agent_completion_flags function

feat(fallback_handler.py): enhance fallback handling by accepting the RAgents type and improving error handling
fix(config.py): update RAgents type definition to include both CompiledGraph and CiaynAgent for better type safety
refactor(fallback_handler.py): streamline fallback model invocation and response handling for improved readability and functionality
Ariel Frischer 2025-02-12 15:35:31 -08:00
parent 6e8b0f2e42
commit 96b41458a1
3 changed files with 99 additions and 118 deletions

agent_utils.py

@@ -28,7 +28,7 @@ from rich.markdown import Markdown
 from rich.panel import Panel
 from ra_aid.agents.ciayn_agent import CiaynAgent
-from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
+from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents
 from ra_aid.console.formatting import print_error, print_stage_header
 from ra_aid.console.output import print_agent_output
 from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
@@ -807,16 +807,10 @@ def _decrement_agent_depth():
     _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1


-def _run_agent_stream(agent: CompiledGraph, prompt: str, config: dict):
-    for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
-        logger.debug("Agent output: %s", chunk)
-        check_interrupt()
-        print_agent_output(chunk)
-        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
-            _global_memory["plan_completed"] = False
-            _global_memory["task_completed"] = False
-            _global_memory["completion_message"] = ""
-            break
+def reset_agent_completion_flags():
+    _global_memory["plan_completed"] = False
+    _global_memory["task_completed"] = False
+    _global_memory["completion_message"] = ""


 def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
@@ -842,8 +836,26 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
     time.sleep(0.1)


+def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
+    for chunk in agent.stream({"messages": msg_list}, config):
+        logger.debug("Agent output: %s", chunk)
+        check_interrupt()
+        print_agent_output(chunk)
+        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
+            reset_agent_completion_flags()
+            break
+        check_interrupt()
+        print_agent_output(chunk)
+        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
+            reset_agent_completion_flags()
+            break
+
+
 def run_agent_with_retry(
-    agent, prompt: str, config: dict, fallback_handler: FallbackHandler
+    agent: RAgents,
+    prompt: str,
+    config: dict,
+    fallback_handler: FallbackHandler,
 ) -> Optional[str]:
     """Run an agent with retry logic for API errors."""
     logger.debug("Running agent with prompt length: %d", len(prompt))
@ -854,6 +866,7 @@ def run_agent_with_retry(
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES) _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
auto_test = config.get("auto_test", False) auto_test = config.get("auto_test", False)
original_prompt = prompt original_prompt = prompt
msg_list = [HumanMessage(content=prompt)]
with InterruptibleSection(): with InterruptibleSection():
try: try:
@@ -862,7 +875,7 @@ def run_agent_with_retry(
                 logger.debug("Attempt %d/%d", attempt + 1, max_retries)
                 check_interrupt()
                 try:
-                    _run_agent_stream(agent, prompt, config)
+                    _run_agent_stream(agent, msg_list, config)
                     fallback_handler.reset_fallback_handler()
                     should_break, prompt, auto_test, test_attempts = (
                         _execute_test_command_wrapper(
@@ -878,7 +891,7 @@ def run_agent_with_retry(
                 except ToolExecutionError as e:
                     fallback_response = fallback_handler.handle_failure(e, agent)
                     if fallback_response:
-                        prompt = original_prompt + "\n" + str(fallback_response)
+                        msg_list.extend(fallback_response)
                         continue
                 except (KeyboardInterrupt, AgentInterrupt):
                     raise
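For context, here is a minimal self-contained sketch of the new control flow (the run_with_fallback helper, ToolError class, and stream_once callable are invented for illustration, not the project's actual code): the retry loop now threads one growing message list through the agent, and fallback output is appended as messages rather than concatenated onto the prompt string.

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage

class ToolError(Exception):  # stand-in for ra_aid's ToolExecutionError
    pass

def run_with_fallback(stream_once, prompt: str) -> list[BaseMessage]:
    # stream_once(msg_list) runs one agent pass and raises ToolError on tool failure.
    msg_list: list[BaseMessage] = [HumanMessage(content=prompt)]
    for _attempt in range(3):
        try:
            stream_once(msg_list)
            break
        except ToolError as exc:
            # Append the failure context as messages so the next attempt sees
            # the full structured history instead of a rebuilt prompt string.
            msg_list.extend([SystemMessage(content=f"tool failed: {exc}")])
    return msg_list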

config.py

@@ -15,3 +15,8 @@ VALID_PROVIDERS = [
     "deepseek",
     "gemini",
 ]
+
+from ra_aid.agents.ciayn_agent import CiaynAgent
+from langgraph.graph.graph import CompiledGraph
+
+RAgents = CompiledGraph | CiaynAgent
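As a quick illustration of how the new union type can be consumed downstream (the describe_agent helper is hypothetical, not part of the commit): the PEP 604 `X | Y` syntax produces a single alias that both static type checkers and runtime isinstance() checks can narrow.

from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.config import RAgents

def describe_agent(agent: RAgents) -> str:
    # isinstance() narrows the CompiledGraph | CiaynAgent union at runtime
    # and for type checkers alike.
    if isinstance(agent, CiaynAgent):
        return "CIAYN agent (prompt-driven tool use)"
    return "LangGraph CompiledGraph agent"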

fallback_handler.py

@@ -1,16 +1,16 @@
 import json
 import re

+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.tools import BaseTool
-from langgraph.graph.graph import CompiledGraph
 from langgraph.graph.message import BaseMessage
-from langchain_core.messages import SystemMessage, HumanMessage

-from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.config import (
     DEFAULT_MAX_TOOL_FAILURES,
     FALLBACK_TOOL_MODEL_LIMIT,
     RETRY_FALLBACK_COUNT,
+    RAgents,
 )
 from ra_aid.console.output import cpm
 from ra_aid.exceptions import ToolExecutionError
@@ -47,11 +47,18 @@ class FallbackHandler:
         self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
         self.tool_failure_consecutive_failures = 0
         self.failed_messages: list[BaseMessage] = []
-        self.tool_failure_used_fallbacks = set()
         self.current_failing_tool_name = ""
-        self.current_tool_to_bind = None
+        self.current_tool_to_bind: None | BaseTool = None
+        cpm(
+            "Fallback models selected: "
+            + ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
+            title="Fallback Models",
+        )
+
+    def _format_model(self, m: dict) -> str:
+        return f"{m.get('model', '')} ({m.get('type', 'prompt')})"

-    def _load_fallback_tool_models(self, config):
+    def _load_fallback_tool_models(self, _config):
         """
         Load and return fallback tool models based on the provided configuration.
@@ -90,12 +97,9 @@ class FallbackHandler:
                 "\nSkipped top tool calling models due to missing provider ENV API keys: "
                 + ", ".join(skipped)
             )
-            cpm(message, title="Fallback Models")
         return final_models

-    def handle_failure(
-        self, error: ToolExecutionError, agent: CiaynAgent | CompiledGraph
-    ):
+    def handle_failure(self, error: ToolExecutionError, agent: RAgents):
         """
         Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
@@ -137,10 +141,10 @@ class FallbackHandler:
     def attempt_fallback(self):
         """
-        Initiate the fallback process by iterating over all fallback models and triggering the appropriate fallback method.
+        Initiate the fallback process by iterating over all fallback models to attempt to fix the failing tool call.

         Returns:
-            The response from a fallback model if any, otherwise None.
+            List of [raw_llm_response (SystemMessage), tool_call_result (SystemMessage)] or None.
         """
         logger.error(
             f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
@@ -150,12 +154,10 @@ class FallbackHandler:
             title="Fallback Notification",
         )
         for fallback_model in self.fallback_tool_models:
-            if fallback_model.get("type", "prompt").lower() == "fc":
-                response = self.attempt_fallback_function(fallback_model)
-            else:
-                response = self.attempt_fallback_prompt(fallback_model)
-            if response:
-                return response
+            result_list = self.invoke_fallback(fallback_model)
+            if result_list:
+                msg_list_response = [SystemMessage(str(msg)) for msg in result_list]
+                return msg_list_response
         cpm("All fallback models have failed", title="Fallback Failed")
         return None
@@ -165,8 +167,9 @@ class FallbackHandler:
         """
         self.tool_failure_consecutive_failures = 0
         self.failed_messages.clear()
-        self.tool_failure_used_fallbacks.clear()
         self.fallback_tool_models = self._load_fallback_tool_models(self.config)
+        self.current_failing_tool_name = ""
+        self.current_tool_to_bind = None

     def _reset_on_new_failure(self, failed_tool_call_name):
         if (
@@ -218,9 +221,21 @@ class FallbackHandler:
         )
         return tool_to_bind

-    def attempt_fallback_prompt(self, fallback_model):
+    def _bind_tool_model(self, simple_model: BaseChatModel, fallback_model):
+        if fallback_model.get("type", "prompt").lower() == "fc":
+            # Force tool calling with tool_choice param.
+            bound_model = simple_model.bind_tools(
+                [self.current_tool_to_bind],
+                tool_choice=self.current_failing_tool_name,
+            )
+        else:
+            # Do not force tool calling (Prompt method)
+            bound_model = simple_model.bind_tools([self.current_tool_to_bind])
+        return bound_model
+
+    def invoke_fallback(self, fallback_model):
         """
-        Attempt a prompt-based fallback by invoking the current failing tool with the given fallback model.
+        Attempt a Prompt or function-calling fallback by invoking the current failing tool with the given fallback model.

         Args:
             fallback_model (dict): The fallback model to use.
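For background on the _bind_tool_model branch above: LangChain chat models expose bind_tools, and many providers also accept a tool_choice argument that forces a specific tool to be called. A sketch assuming the langchain-openai package is installed (the model name and emit_answer tool are placeholders, and tool_choice support varies by provider):

from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

@tool
def emit_answer(text: str) -> str:
    """Return the final answer text."""
    return text

model = ChatOpenAI(model="gpt-4o-mini")
# "fc" path: force the model to call emit_answer on every response.
forced = model.bind_tools([emit_answer], tool_choice="emit_answer")
# "prompt" path: the tool is available, but the model may answer in plain text.
optional = model.bind_tools([emit_answer])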
@@ -229,88 +244,33 @@ class FallbackHandler:
             The response from the fallback model invocation, or None if failed.
         """
         try:
-            logger.debug(f"Trying fallback model: {fallback_model['model']}")
+            logger.debug(f"Trying fallback model: {self._format_model(fallback_model)}")
             simple_model = initialize_llm(
                 fallback_model["provider"], fallback_model["model"]
             )
-            binded_model = simple_model.bind_tools(
-                [self.current_tool_to_bind],
-                tool_choice=self.current_failing_tool_name,
-            )
-            retry_model = binded_model.with_retry(
+
+            bound_model = self._bind_tool_model(simple_model, fallback_model)
+
+            retry_model = bound_model.with_retry(
                 stop_after_attempt=RETRY_FALLBACK_COUNT
             )
-            # msg_list = []
             msg_list = self.construct_prompt_msg_list()
-            # response = retry_model.invoke(self.current_failing_tool_name)
             response = retry_model.invoke(msg_list)
-            cpm(f"response={response}")
-            self.tool_failure_used_fallbacks.add(fallback_model["model"])
+            logger.debug(f"raw llm response={response}")
+
             tool_call = self.base_message_to_tool_call_dict(response)
-            if tool_call:
-                result = self.invoke_prompt_tool_call(tool_call)
-                cpm(f"result={result}")
-                logger.debug(
-                    "Prompt-based fallback executed successfully with model: "
-                    + fallback_model["model"]
-                )
-                self.reset_fallback_handler()
-                return result
-            else:
-                cpm(
-                    response.content if hasattr(response, "content") else response,
-                    title="Fallback Model Response: " + fallback_model["model"],
-                )
-                return response
+
+            tool_call_result = self.invoke_prompt_tool_call(tool_call)
+            cpm(str(tool_call_result), title="Fallback Tool Call Result")
+            logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}")
+
+            self.reset_fallback_handler()
+            return [response, tool_call_result]
         except Exception as e:
             if isinstance(e, KeyboardInterrupt):
                 raise
-            logger.error(
-                f"Prompt-based fallback with model {fallback_model['model']} failed: {e}"
-            )
-            return None
-
-    def attempt_fallback_function(self, fallback_model):
-        """
-        Attempt a function-calling fallback by invoking the current failing tool with the given fallback model.
-
-        Args:
-            fallback_model (dict): The fallback model to use.
-
-        Returns:
-            The response from the fallback model invocation, or None if failed.
-        """
-        try:
-            logger.debug(f"Trying fallback model: {fallback_model['model']}")
-            simple_model = initialize_llm(
-                fallback_model["provider"], fallback_model["model"]
-            )
-            binded_model = simple_model.bind_tools(
-                [self.current_tool_to_bind],
-                tool_choice=self.current_failing_tool_name,
-            )
-            retry_model = binded_model.with_retry(
-                stop_after_attempt=RETRY_FALLBACK_COUNT
-            )
-            response = retry_model.invoke(self.current_failing_tool_name)
-            cpm(f"response={response}")
-            self.tool_failure_used_fallbacks.add(fallback_model["model"])
-            logger.debug(
-                "Function-calling fallback executed successfully with model: "
-                + fallback_model["model"]
-            )
-            cpm(
-                response.content if hasattr(response, "content") else response,
-                title="Fallback Model Response: " + fallback_model["model"],
-            )
-            return response
-        except Exception as e:
-            if isinstance(e, KeyboardInterrupt):
-                raise
-            logger.error(
-                f"Function-calling fallback with model {fallback_model['model']} failed: {e}"
-            )
+            logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}")
             return None

     def construct_prompt_msg_list(self):
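The with_retry(stop_after_attempt=...) call in invoke_fallback is the standard LangChain Runnable retry helper. A self-contained sketch with an invented flaky function, to show the contract:

from langchain_core.runnables import RunnableLambda

attempts = {"count": 0}

def flaky(_input: str) -> str:
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

# Retries the wrapped runnable up to three times before giving up.
retrying = RunnableLambda(flaky).with_retry(stop_after_attempt=3)
print(retrying.invoke("hi"))  # succeeds on the third attempt -> "ok"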
@@ -367,19 +327,22 @@ class FallbackHandler:
             otherwise None.
         """
         tool_calls = self.get_tool_calls(response)
-        if tool_calls:
-            if len(tool_calls) > 1:
-                logger.warning("Multiple tool calls detected, using the first one")
-            tool_call = tool_calls[0]
-            return {
-                "id": tool_call["id"],
-                "type": tool_call["type"],
-                "name": tool_call["function"]["name"],
-                "arguments": self._parse_tool_arguments(
-                    tool_call["function"]["arguments"]
-                ),
-            }
-        return None
+
+        if not tool_calls:
+            raise Exception(
+                f"Could not extract tool_call_dict from response: {response}"
+            )
+
+        if len(tool_calls) > 1:
+            logger.warning("Multiple tool calls detected, using the first one")
+
+        tool_call = tool_calls[0]
+        return {
+            "id": tool_call["id"],
+            "type": tool_call["type"],
+            "name": tool_call["function"]["name"],
+            "arguments": self._parse_tool_arguments(tool_call["function"]["arguments"]),
+        }

     def _parse_tool_arguments(self, tool_arguments):
         """