refactor(agent_utils.py): extract helper functions from run_agent_with_retry for readability and maintainability

feat(agent_utils.py): add new helper functions for handling API errors and managing interrupt signals
fix(agent_utils.py): improve error handling in tool execution and retry logic
feat(fallback_handler.py): enhance fallback handling by binding tools correctly during retries
test(tests): add unit tests for new helper functions and refactored logic in agent_utils.py
Ariel Frischer 2025-02-11 12:16:04 -08:00
parent de489584e5
commit 1388067769
4 changed files with 223 additions and 100 deletions
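
At a high level, the refactor decomposes run_agent_with_retry into small single-purpose helpers. The outline below is a hypothetical, simplified sketch of the resulting control flow (the `helpers` object and the bare `Exception` catch are stand-ins; the real helper names, signatures, and retryable error tuple appear in the ra_aid/agent_utils.py diff below):

# Illustrative outline only; not the actual implementation.
def run_agent_outline(agent, prompt, config, helpers):
    original_handler = helpers.setup_interrupt_handling()  # save + install SIGINT handler
    try:
        helpers.increment_agent_depth()
        for attempt in range(20):  # max_retries
            try:
                helpers.run_agent_stream(agent, prompt, config)
                return "Agent run completed successfully"
            except Exception as e:  # stands in for the retryable API error tuple
                helpers.handle_api_error(e, attempt, 20, 1)
    finally:
        helpers.decrement_agent_depth()
        helpers.restore_interrupt_handling(original_handler)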

ra_aid/agent_utils.py

@@ -9,10 +9,16 @@ import uuid
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Sequence

+from langgraph.graph.graph import CompiledGraph
 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
 from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import BaseMessage, HumanMessage, trim_messages
+from langchain_core.messages import (
+    BaseMessage,
+    HumanMessage,
+    InvalidToolCall,
+    trim_messages,
+)
 from langchain_core.tools import tool
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
@@ -26,7 +32,8 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
 from ra_aid.console.formatting import print_error, print_stage_header
 from ra_aid.console.output import print_agent_output
-from ra_aid.exceptions import AgentInterrupt
+from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
+from ra_aid.fallback_handler import FallbackHandler
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
 from ra_aid.project_info import (
@@ -238,7 +245,7 @@ def create_agent(
     *,
     checkpointer: Any = None,
     agent_type: str = "default",
-) -> Any:
+):
     """Create a react agent with the given configuration.

     Args:
@@ -775,61 +782,98 @@ def check_interrupt():
         raise AgentInterrupt("Interrupt requested")


-def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
-    """Run an agent with retry logic for API errors."""
-    logger.debug("Running agent with prompt length: %d", len(prompt))
-    original_handler = None
+# New helper functions for run_agent_with_retry refactoring
+def _setup_interrupt_handling():
     if threading.current_thread() is threading.main_thread():
         original_handler = signal.getsignal(signal.SIGINT)
         signal.signal(signal.SIGINT, _request_interrupt)
+        return original_handler
+    return None
+
+
+def _restore_interrupt_handling(original_handler):
+    if original_handler and threading.current_thread() is threading.main_thread():
+        signal.signal(signal.SIGINT, original_handler)
+
+
+def _increment_agent_depth():
+    current_depth = _global_memory.get("agent_depth", 0)
+    _global_memory["agent_depth"] = current_depth + 1
+
+
+def _decrement_agent_depth():
+    _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
+
+
+def _run_agent_stream(agent, prompt, config):
+    for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
+        logger.debug("Agent output: %s", chunk)
+        check_interrupt()
+        print_agent_output(chunk)
+        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
+            _global_memory["plan_completed"] = False
+            _global_memory["task_completed"] = False
+            _global_memory["completion_message"] = ""
+            break
+
+
+def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
+    return execute_test_command(config, original_prompt, test_attempts, auto_test)
+
+
+def _handle_api_error(e, attempt, max_retries, base_delay):
+    if isinstance(e, ValueError):
+        error_str = str(e).lower()
+        if "code" not in error_str or "429" not in error_str:
+            raise e
+    if attempt == max_retries - 1:
+        logger.error("Max retries reached, failing: %s", str(e))
+        raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
+    logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
+    delay = base_delay * (2**attempt)
+    print_error(
+        f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
+    )
+    start = time.monotonic()
+    while time.monotonic() - start < delay:
+        check_interrupt()
+        time.sleep(0.1)
+
+
+def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
+    """Run an agent with retry logic for API errors."""
+    logger.debug("Running agent with prompt length: %d", len(prompt))
+    original_handler = _setup_interrupt_handling()
     max_retries = 20
     base_delay = 1
     test_attempts = 0
     _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
     auto_test = config.get("auto_test", False)
     original_prompt = prompt
+    fallback_handler = FallbackHandler(config)

     with InterruptibleSection():
         try:
-            # Track agent execution depth
-            current_depth = _global_memory.get("agent_depth", 0)
-            _global_memory["agent_depth"] = current_depth + 1
+            _increment_agent_depth()
             for attempt in range(max_retries):
                 logger.debug("Attempt %d/%d", attempt + 1, max_retries)
                 check_interrupt()
                 try:
-                    for chunk in agent.stream(
-                        {"messages": [HumanMessage(content=prompt)]}, config
-                    ):
-                        logger.debug("Agent output: %s", chunk)
-                        check_interrupt()
-                        print_agent_output(chunk)
-                        if _global_memory["plan_completed"]:
-                            _global_memory["plan_completed"] = False
-                            _global_memory["task_completed"] = False
-                            _global_memory["completion_message"] = ""
-                            break
-                        if _global_memory["task_completed"]:
-                            _global_memory["task_completed"] = False
-                            _global_memory["completion_message"] = ""
-                            break
-                    # Execute test command if configured
+                    _run_agent_stream(agent, prompt, config)
+                    fallback_handler.reset_fallback_handler()
                     should_break, prompt, auto_test, test_attempts = (
-                        execute_test_command(
-                            config, original_prompt, test_attempts, auto_test
+                        _execute_test_command_wrapper(
+                            original_prompt, config, test_attempts, auto_test
                         )
                     )
                     if should_break:
                         break
                     if prompt != original_prompt:
                         continue
                     logger.debug("Agent run completed successfully")
                     return "Agent run completed successfully"
+                except (ToolExecutionError, InvalidToolCall) as e:
+                    _handle_tool_execution_error(fallback_handler, agent, e)
                 except (KeyboardInterrupt, AgentInterrupt):
                     raise
                 except (
@@ -839,35 +883,15 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
                     APIError,
                     ValueError,
                 ) as e:
-                    if isinstance(e, ValueError):
-                        error_str = str(e).lower()
-                        if "code" not in error_str or "429" not in error_str:
-                            raise  # Re-raise ValueError if it's not a Lambda 429
-                    if attempt == max_retries - 1:
-                        logger.error("Max retries reached, failing: %s", str(e))
-                        raise RuntimeError(
-                            f"Max retries ({max_retries}) exceeded. Last error: {e}"
-                        )
-                    logger.warning(
-                        "API error (attempt %d/%d): %s",
-                        attempt + 1,
-                        max_retries,
-                        str(e),
-                    )
-                    delay = base_delay * (2**attempt)
-                    print_error(
-                        f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
-                    )
-                    start = time.monotonic()
-                    while time.monotonic() - start < delay:
-                        check_interrupt()
-                        time.sleep(0.1)
+                    _handle_api_error(e, attempt, max_retries, base_delay)
         finally:
-            # Reset depth tracking
-            _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
-            if (
-                original_handler
-                and threading.current_thread() is threading.main_thread()
-            ):
-                signal.signal(signal.SIGINT, original_handler)
+            _decrement_agent_depth()
+            _restore_interrupt_handling(original_handler)
+
+
+def _handle_tool_execution_error(
+    fallback_handler: FallbackHandler,
+    agent: CiaynAgent | CompiledGraph,
+    error: ToolExecutionError | InvalidToolCall,
+):
+    fallback_handler.handle_failure("Tool execution error", error, agent)
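
For reference, the extracted _handle_api_error waits base_delay * 2**attempt seconds (1 s, 2 s, 4 s, ...) and sleeps in 0.1 s slices so check_interrupt can abort the wait promptly. Below is a minimal, self-contained sketch of that wait pattern; the check_interrupt parameter is a stand-in for ra_aid's interrupt check, and the function name is hypothetical:

import time

def interruptible_backoff(attempt, base_delay=1.0, check_interrupt=lambda: None):
    # Exponential backoff: the delay doubles with each attempt.
    delay = base_delay * (2 ** attempt)
    start = time.monotonic()
    while time.monotonic() - start < delay:
        check_interrupt()  # raises (e.g. AgentInterrupt) to abort the wait early
        time.sleep(0.1)    # short slices keep the loop responsive to interrupts

# interruptible_backoff(3) waits ~8 seconds, checking for interrupts 10x per second.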

ra_aid/agents/ciayn_agent.py

@@ -4,7 +4,6 @@ from typing import Any, Dict, Generator, List, Optional, Union
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage

-from ra_aid.fallback_handler import FallbackHandler
 from ra_aid.exceptions import ToolExecutionError
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
@@ -84,13 +83,12 @@ class CiaynAgent:
             tools: List of tools available to the agent
             max_history_messages: Maximum number of messages to keep in chat history
             max_tokens: Maximum number of tokens allowed in message history (None for no limit)
-            config: Optional configuration dictionary for fallback settings
+            config: Optional configuration dictionary
         """
         if config is None:
             config = {}
         self.config = config
         self.provider = config.get("provider", "openai")
-        self.fallback_handler = FallbackHandler(config)

         self.model = model
         self.tools = tools
@@ -232,39 +230,29 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
         return base_prompt

     def _execute_tool(self, code: str) -> str:
-        """Execute a tool call with retry and fallback logic and return its result."""
-        max_retries = 3
-        retries = 0
-        last_error = None
-
-        while retries < max_retries:
-            try:
-                logger.debug(
-                    f"_execute_tool: attempt {retries+1}, original code: {code}"
-                )
-                code = code.strip()
-                if validate_function_call_pattern(code):
-                    functions_list = "\n\n".join(self.available_functions)
-                    code = _extract_tool_call(code, functions_list)
-                globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
-                logger.debug(f"_execute_tool: evaluating code: {code}")
-                result = eval(code, globals_dict)
-                logger.debug(
-                    f"_execute_tool: tool executed successfully with result: {result}"
-                )
-                self.fallback_handler.reset_fallback_handler()
-                return result
-            except Exception as e:
-                logger.debug(f"_execute_tool: exception caught: {e}")
-                self._handle_tool_failure(code, e)
-                last_error = e
-                retries += 1
-                logger.debug(f"_execute_tool: retrying, new attempt count: {retries}")
-        raise ToolExecutionError(
-            f"Error executing code after {max_retries} attempts: {str(last_error)}"
-        )
-
-    def _handle_tool_failure(self, code: str, error: Exception) -> None:
-        self.fallback_handler.handle_failure(code, error, logger, self)
+        """Execute a tool call and return its result."""
+        globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
+
+        try:
+            code = code.strip()
+            logger.debug(f"_execute_tool: stripped code: {code}")
+            # if the eval fails, try to extract it via a model call
+            if validate_function_call_pattern(code):
+                functions_list = "\n\n".join(self.available_functions)
+                logger.debug(f"_execute_tool: code before extraction: {code}")
+                code = _extract_tool_call(code, functions_list)
+                logger.debug(f"_execute_tool: code after extraction: {code}")
+            logger.debug(
+                f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}"
+            )
+            result = eval(code.strip(), globals_dict)
+            logger.debug(f"_execute_tool: result: {result}")
+            return result
+        except Exception as e:
+            error_msg = f"Error executing code: {str(e)}"
+            raise ToolExecutionError(error_msg)

     def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
         """Create an agent chunk in the format expected by print_agent_output."""

ra_aid/fallback_handler.py

@@ -9,6 +9,8 @@ from rich.markdown import Markdown
 from rich.panel import Panel

 from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env

+logger = get_logger(__name__)
+

 class FallbackHandler:
     """
@@ -73,6 +75,7 @@ class FallbackHandler:
         for item in supported:
             if "type" not in item:
                 item["type"] = "prompt"
+            item["model"] = item["model"].lower()
             final_models.append(item)
         message = "Fallback models selected: " + ", ".join(
             [m["model"] for m in final_models]
@@ -85,7 +88,7 @@
         console.print(Panel(Markdown(message), title="Fallback Models"))
         return final_models

-    def handle_failure(self, code: str, error: Exception, logger, agent):
+    def handle_failure(self, code: str, error: Exception, agent):
         """
         Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
@@ -173,8 +176,23 @@
             simple_model = initialize_llm(
                 fallback_model["provider"], fallback_model["model"]
             )
+            tool_to_bind = next(
+                (
+                    t
+                    for t in agent.tools
+                    if t.func.__name__ == failed_tool_call_name
+                ),
+                None,
+            )
+            if tool_to_bind is None:
+                logger.debug(
+                    f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
+                )
+                raise Exception(
+                    f"Tool {failed_tool_call_name} not found in agent.tools"
+                )
             binded_model = simple_model.bind_tools(
-                agent.tools, tool_choice=failed_tool_call_name
+                [tool_to_bind], tool_choice=failed_tool_call_name
             )
             retry_model = binded_model.with_retry(
                 stop_after_attempt=RETRY_FALLBACK_COUNT
@@ -221,8 +239,23 @@
             simple_model = initialize_llm(
                 fallback_model["provider"], fallback_model["model"]
             )
+            tool_to_bind = next(
+                (
+                    t
+                    for t in agent.tools
+                    if t.func.__name__ == failed_tool_call_name
+                ),
+                None,
+            )
+            if tool_to_bind is None:
+                logger.debug(
+                    f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
+                )
+                raise Exception(
+                    f"Tool {failed_tool_call_name} not found in agent.tools"
+                )
             binded_model = simple_model.bind_tools(
-                agent.tools, tool_choice=failed_tool_call_name
+                [tool_to_bind], tool_choice=failed_tool_call_name
             )
             retry_model = binded_model.with_retry(
                 stop_after_attempt=RETRY_FALLBACK_COUNT
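
The key behavioral change in both hunks: instead of re-binding the entire toolset, the fallback model is bound to only the failed tool, with tool_choice forcing it to retry exactly that call. A minimal sketch of the lookup step, assuming LangChain-style tool objects that wrap the underlying callable in a .func attribute (find_failed_tool is a hypothetical name):

def find_failed_tool(tools, failed_tool_call_name):
    # Return the first tool whose underlying function name matches,
    # or None if the agent has no such tool.
    return next(
        (t for t in tools if t.func.__name__ == failed_tool_call_name),
        None,
    )

# Usage, roughly as in the diff above (names illustrative):
# tool = find_failed_tool(agent.tools, failed_tool_call_name)
# if tool is None:
#     raise Exception(f"Tool {failed_tool_call_name} not found in agent.tools")
# binded_model = simple_model.bind_tools([tool], tool_choice=failed_tool_call_name)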

tests/test_agent_utils.py (filename inferred from the test content; not shown in the original view)

@@ -275,3 +275,81 @@ def test_get_model_token_limit_planner(mock_memory):
         mock_get_info.return_value = {"max_input_tokens": 120000}
         token_limit = get_model_token_limit(config, "planner")
         assert token_limit == 120000
+
+
+# New tests for private helper methods in agent_utils.py
+def test_setup_and_restore_interrupt_handling():
+    import signal, threading
+    from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt
+
+    original_handler = signal.getsignal(signal.SIGINT)
+    handler = _setup_interrupt_handling()
+    # Verify the SIGINT handler is set to _request_interrupt
+    assert signal.getsignal(signal.SIGINT) == _request_interrupt
+    _restore_interrupt_handling(handler)
+    # Verify the SIGINT handler is restored to the original
+    assert signal.getsignal(signal.SIGINT) == original_handler
+
+
+def test_increment_and_decrement_agent_depth():
+    from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory
+
+    _global_memory["agent_depth"] = 10
+    _increment_agent_depth()
+    assert _global_memory["agent_depth"] == 11
+    _decrement_agent_depth()
+    assert _global_memory["agent_depth"] == 10
+
+
+def test_run_agent_stream(monkeypatch):
+    from ra_aid.agent_utils import _run_agent_stream, _global_memory
+
+    # Create a dummy agent that yields one chunk
+    class DummyAgent:
+        def stream(self, msg, cfg):
+            yield {"content": "chunk1"}
+
+    dummy_agent = DummyAgent()
+    # Set flags so that _run_agent_stream will reset them
+    _global_memory["plan_completed"] = True
+    _global_memory["task_completed"] = True
+    _global_memory["completion_message"] = "existing"
+
+    call_flag = {"called": False}
+
+    def fake_print_agent_output(chunk):
+        call_flag["called"] = True
+
+    monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output)
+    _run_agent_stream(dummy_agent, "dummy prompt", {})
+    assert call_flag["called"]
+    assert _global_memory["plan_completed"] is False
+    assert _global_memory["task_completed"] is False
+    assert _global_memory["completion_message"] == ""
+
+
+def test_execute_test_command_wrapper(monkeypatch):
+    from ra_aid.agent_utils import _execute_test_command_wrapper
+
+    # Patch execute_test_command to return a testable tuple
+    def fake_execute(config, orig, tests, auto):
+        return (True, "new prompt", auto, tests + 1)
+
+    monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
+    result = _execute_test_command_wrapper("orig", {}, 0, False)
+    assert result == (True, "new prompt", False, 1)
+
+
+def test_handle_api_error_valueerror():
+    from ra_aid.agent_utils import _handle_api_error
+    import pytest
+
+    # ValueError not containing "code" or "429" should be re-raised
+    with pytest.raises(ValueError):
+        _handle_api_error(ValueError("some error"), 0, 5, 1)
+
+
+def test_handle_api_error_max_retries():
+    from ra_aid.agent_utils import _handle_api_error
+    import pytest
+
+    # When attempt reaches max retries, a RuntimeError should be raised
+    with pytest.raises(RuntimeError):
+        _handle_api_error(Exception("error code 429"), 4, 5, 1)
+
+
+def test_handle_api_error_retry(monkeypatch):
+    from ra_aid.agent_utils import _handle_api_error
+    import time
+
+    # Patch time.monotonic and time.sleep to simulate immediate delay expiration
+    fake_time = [0]
+
+    def fake_monotonic():
+        fake_time[0] += 0.5
+        return fake_time[0]
+
+    monkeypatch.setattr(time, "monotonic", fake_monotonic)
+    monkeypatch.setattr(time, "sleep", lambda s: None)
+    # Should not raise error when attempt is lower than max retries
+    _handle_api_error(Exception("error code 429"), 0, 5, 1)