refactor(agent_utils.py): refactor run_agent_with_retry function for better readability and maintainability by extracting helper functions

feat(agent_utils.py): add new helper functions for handling API errors and managing interrupt signals
fix(agent_utils.py): improve error handling in tool execution and retry logic
feat(fallback_handler.py): enhance fallback handling by binding tools correctly during retries
test(tests): add unit tests for new helper functions and refactored logic in agent_utils.py
This commit is contained in:
Ariel Frischer 2025-02-11 12:16:04 -08:00
parent de489584e5
commit 1388067769
4 changed files with 223 additions and 100 deletions

View File

@ -9,10 +9,16 @@ import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Sequence
from langgraph.graph.graph import CompiledGraph
import litellm
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, trim_messages
from langchain_core.messages import (
BaseMessage,
HumanMessage,
InvalidToolCall,
trim_messages,
)
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
@ -26,7 +32,8 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
from ra_aid.console.formatting import print_error, print_stage_header
from ra_aid.console.output import print_agent_output
from ra_aid.exceptions import AgentInterrupt
from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.project_info import (
@ -238,7 +245,7 @@ def create_agent(
*,
checkpointer: Any = None,
agent_type: str = "default",
) -> Any:
):
"""Create a react agent with the given configuration.
Args:
@ -775,61 +782,98 @@ def check_interrupt():
raise AgentInterrupt("Interrupt requested")
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
"""Run an agent with retry logic for API errors."""
logger.debug("Running agent with prompt length: %d", len(prompt))
original_handler = None
# New helper functions for run_agent_with_retry refactoring
def _setup_interrupt_handling():
if threading.current_thread() is threading.main_thread():
original_handler = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, _request_interrupt)
return original_handler
return None
def _restore_interrupt_handling(original_handler):
if original_handler and threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, original_handler)
def _increment_agent_depth():
current_depth = _global_memory.get("agent_depth", 0)
_global_memory["agent_depth"] = current_depth + 1
def _decrement_agent_depth():
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
def _run_agent_stream(agent, prompt, config):
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
logger.debug("Agent output: %s", chunk)
check_interrupt()
print_agent_output(chunk)
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
_global_memory["plan_completed"] = False
_global_memory["task_completed"] = False
_global_memory["completion_message"] = ""
break
def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
return execute_test_command(config, original_prompt, test_attempts, auto_test)
def _handle_api_error(e, attempt, max_retries, base_delay):
if isinstance(e, ValueError):
error_str = str(e).lower()
if "code" not in error_str or "429" not in error_str:
raise e
if attempt == max_retries - 1:
logger.error("Max retries reached, failing: %s", str(e))
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
delay = base_delay * (2**attempt)
print_error(
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
)
start = time.monotonic()
while time.monotonic() - start < delay:
check_interrupt()
time.sleep(0.1)
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
"""Run an agent with retry logic for API errors."""
logger.debug("Running agent with prompt length: %d", len(prompt))
original_handler = _setup_interrupt_handling()
max_retries = 20
base_delay = 1
test_attempts = 0
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
auto_test = config.get("auto_test", False)
original_prompt = prompt
fallback_handler = FallbackHandler(config)
with InterruptibleSection():
try:
# Track agent execution depth
current_depth = _global_memory.get("agent_depth", 0)
_global_memory["agent_depth"] = current_depth + 1
_increment_agent_depth()
for attempt in range(max_retries):
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
check_interrupt()
try:
for chunk in agent.stream(
{"messages": [HumanMessage(content=prompt)]}, config
):
logger.debug("Agent output: %s", chunk)
check_interrupt()
print_agent_output(chunk)
if _global_memory["plan_completed"]:
_global_memory["plan_completed"] = False
_global_memory["task_completed"] = False
_global_memory["completion_message"] = ""
break
if _global_memory["task_completed"]:
_global_memory["task_completed"] = False
_global_memory["completion_message"] = ""
break
# Execute test command if configured
_run_agent_stream(agent, prompt, config)
fallback_handler.reset_fallback_handler()
should_break, prompt, auto_test, test_attempts = (
execute_test_command(
config, original_prompt, test_attempts, auto_test
_execute_test_command_wrapper(
original_prompt, config, test_attempts, auto_test
)
)
if should_break:
break
if prompt != original_prompt:
continue
logger.debug("Agent run completed successfully")
return "Agent run completed successfully"
except (ToolExecutionError, InvalidToolCall) as e:
_handle_tool_execution_error(fallback_handler, agent, e)
except (KeyboardInterrupt, AgentInterrupt):
raise
except (
@ -839,35 +883,15 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
APIError,
ValueError,
) as e:
if isinstance(e, ValueError):
error_str = str(e).lower()
if "code" not in error_str or "429" not in error_str:
raise # Re-raise ValueError if it's not a Lambda 429
if attempt == max_retries - 1:
logger.error("Max retries reached, failing: %s", str(e))
raise RuntimeError(
f"Max retries ({max_retries}) exceeded. Last error: {e}"
)
logger.warning(
"API error (attempt %d/%d): %s",
attempt + 1,
max_retries,
str(e),
)
delay = base_delay * (2**attempt)
print_error(
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
)
start = time.monotonic()
while time.monotonic() - start < delay:
check_interrupt()
time.sleep(0.1)
_handle_api_error(e, attempt, max_retries, base_delay)
finally:
# Reset depth tracking
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
_decrement_agent_depth()
_restore_interrupt_handling(original_handler)
if (
original_handler
and threading.current_thread() is threading.main_thread()
):
signal.signal(signal.SIGINT, original_handler)
def _handle_tool_execution_error(
fallback_handler: FallbackHandler,
agent: CiaynAgent | CompiledGraph,
error: ToolExecutionError | InvalidToolCall,
):
fallback_handler.handle_failure("Tool execution error", error, agent)

View File

@ -4,7 +4,6 @@ from typing import Any, Dict, Generator, List, Optional, Union
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.exceptions import ToolExecutionError
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
@ -84,13 +83,12 @@ class CiaynAgent:
tools: List of tools available to the agent
max_history_messages: Maximum number of messages to keep in chat history
max_tokens: Maximum number of tokens allowed in message history (None for no limit)
config: Optional configuration dictionary for fallback settings
config: Optional configuration dictionary
"""
if config is None:
config = {}
self.config = config
self.provider = config.get("provider", "openai")
self.fallback_handler = FallbackHandler(config)
self.model = model
self.tools = tools
@ -232,39 +230,29 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
return base_prompt
def _execute_tool(self, code: str) -> str:
"""Execute a tool call with retry and fallback logic and return its result."""
max_retries = 3
retries = 0
last_error = None
while retries < max_retries:
try:
logger.debug(
f"_execute_tool: attempt {retries+1}, original code: {code}"
)
code = code.strip()
if validate_function_call_pattern(code):
functions_list = "\n\n".join(self.available_functions)
code = _extract_tool_call(code, functions_list)
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
logger.debug(f"_execute_tool: evaluating code: {code}")
result = eval(code, globals_dict)
logger.debug(
f"_execute_tool: tool executed successfully with result: {result}"
)
self.fallback_handler.reset_fallback_handler()
return result
except Exception as e:
logger.debug(f"_execute_tool: exception caught: {e}")
self._handle_tool_failure(code, e)
last_error = e
retries += 1
logger.debug(f"_execute_tool: retrying, new attempt count: {retries}")
raise ToolExecutionError(
f"Error executing code after {max_retries} attempts: {str(last_error)}"
)
"""Execute a tool call and return its result."""
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
def _handle_tool_failure(self, code: str, error: Exception) -> None:
self.fallback_handler.handle_failure(code, error, logger, self)
try:
code = code.strip()
logger.debug(f"_execute_tool: stripped code: {code}")
# if the eval fails, try to extract it via a model call
if validate_function_call_pattern(code):
functions_list = "\n\n".join(self.available_functions)
logger.debug(f"_execute_tool: code before extraction: {code}")
code = _extract_tool_call(code, functions_list)
logger.debug(f"_execute_tool: code after extraction: {code}")
logger.debug(
f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}"
)
result = eval(code.strip(), globals_dict)
logger.debug(f"_execute_tool: result: {result}")
return result
except Exception as e:
error_msg = f"Error executing code: {str(e)}"
raise ToolExecutionError(error_msg)
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
"""Create an agent chunk in the format expected by print_agent_output."""

View File

@ -9,6 +9,8 @@ from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
logger = get_logger(__name__)
class FallbackHandler:
"""
@ -73,6 +75,7 @@ class FallbackHandler:
for item in supported:
if "type" not in item:
item["type"] = "prompt"
item["model"] = item["model"].lower()
final_models.append(item)
message = "Fallback models selected: " + ", ".join(
[m["model"] for m in final_models]
@ -85,7 +88,7 @@ class FallbackHandler:
console.print(Panel(Markdown(message), title="Fallback Models"))
return final_models
def handle_failure(self, code: str, error: Exception, logger, agent):
def handle_failure(self, code: str, error: Exception, agent):
"""
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
@ -173,8 +176,23 @@ class FallbackHandler:
simple_model = initialize_llm(
fallback_model["provider"], fallback_model["model"]
)
tool_to_bind = next(
(
t
for t in agent.tools
if t.func.__name__ == failed_tool_call_name
),
None,
)
if tool_to_bind is None:
logger.debug(
f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
)
raise Exception(
f"Tool {failed_tool_call_name} not found in agent.tools"
)
binded_model = simple_model.bind_tools(
agent.tools, tool_choice=failed_tool_call_name
[tool_to_bind], tool_choice=failed_tool_call_name
)
retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT
@ -221,8 +239,23 @@ class FallbackHandler:
simple_model = initialize_llm(
fallback_model["provider"], fallback_model["model"]
)
tool_to_bind = next(
(
t
for t in agent.tools
if t.func.__name__ == failed_tool_call_name
),
None,
)
if tool_to_bind is None:
logger.debug(
f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
)
raise Exception(
f"Tool {failed_tool_call_name} not found in agent.tools"
)
binded_model = simple_model.bind_tools(
agent.tools, tool_choice=failed_tool_call_name
[tool_to_bind], tool_choice=failed_tool_call_name
)
retry_model = binded_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT

View File

@ -275,3 +275,81 @@ def test_get_model_token_limit_planner(mock_memory):
mock_get_info.return_value = {"max_input_tokens": 120000}
token_limit = get_model_token_limit(config, "planner")
assert token_limit == 120000
# New tests for private helper methods in agent_utils.py
def test_setup_and_restore_interrupt_handling():
import signal, threading
from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt
original_handler = signal.getsignal(signal.SIGINT)
handler = _setup_interrupt_handling()
# Verify the SIGINT handler is set to _request_interrupt
assert signal.getsignal(signal.SIGINT) == _request_interrupt
_restore_interrupt_handling(handler)
# Verify the SIGINT handler is restored to the original
assert signal.getsignal(signal.SIGINT) == original_handler
def test_increment_and_decrement_agent_depth():
from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory
_global_memory["agent_depth"] = 10
_increment_agent_depth()
assert _global_memory["agent_depth"] == 11
_decrement_agent_depth()
assert _global_memory["agent_depth"] == 10
def test_run_agent_stream(monkeypatch):
from ra_aid.agent_utils import _run_agent_stream, _global_memory
# Create a dummy agent that yields one chunk
class DummyAgent:
def stream(self, msg, cfg):
yield {"content": "chunk1"}
dummy_agent = DummyAgent()
# Set flags so that _run_agent_stream will reset them
_global_memory["plan_completed"] = True
_global_memory["task_completed"] = True
_global_memory["completion_message"] = "existing"
call_flag = {"called": False}
def fake_print_agent_output(chunk):
call_flag["called"] = True
monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output)
_run_agent_stream(dummy_agent, "dummy prompt", {})
assert call_flag["called"]
assert _global_memory["plan_completed"] is False
assert _global_memory["task_completed"] is False
assert _global_memory["completion_message"] == ""
def test_execute_test_command_wrapper(monkeypatch):
from ra_aid.agent_utils import _execute_test_command_wrapper
# Patch execute_test_command to return a testable tuple
def fake_execute(config, orig, tests, auto):
return (True, "new prompt", auto, tests + 1)
monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
result = _execute_test_command_wrapper("orig", {}, 0, False)
assert result == (True, "new prompt", False, 1)
def test_handle_api_error_valueerror():
from ra_aid.agent_utils import _handle_api_error
import pytest
# ValueError not containing "code" or "429" should be re-raised
with pytest.raises(ValueError):
_handle_api_error(ValueError("some error"), 0, 5, 1)
def test_handle_api_error_max_retries():
from ra_aid.agent_utils import _handle_api_error
import pytest
# When attempt reaches max retries, a RuntimeError should be raised
with pytest.raises(RuntimeError):
_handle_api_error(Exception("error code 429"), 4, 5, 1)
def test_handle_api_error_retry(monkeypatch):
from ra_aid.agent_utils import _handle_api_error
import time
# Patch time.monotonic and time.sleep to simulate immediate delay expiration
fake_time = [0]
def fake_monotonic():
fake_time[0] += 0.5
return fake_time[0]
monkeypatch.setattr(time, "monotonic", fake_monotonic)
monkeypatch.setattr(time, "sleep", lambda s: None)
# Should not raise error when attempt is lower than max retries
_handle_api_error(Exception("error code 429"), 0, 5, 1)