refactor(agent_utils.py): extract helper functions from run_agent_with_retry for readability and maintainability

feat(agent_utils.py): add helper functions for handling API errors and managing interrupt signals
fix(agent_utils.py): improve error handling in tool execution and retry logic
feat(fallback_handler.py): bind tools correctly during fallback retries
test(tests): add unit tests for the new helper functions and refactored logic in agent_utils.py
parent de489584e5
commit 1388067769
ra_aid/agent_utils.py

@@ -9,10 +9,16 @@ import uuid
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Sequence

+from langgraph.graph.graph import CompiledGraph
 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
 from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import BaseMessage, HumanMessage, trim_messages
+from langchain_core.messages import (
+    BaseMessage,
+    HumanMessage,
+    InvalidToolCall,
+    trim_messages,
+)
 from langchain_core.tools import tool
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
@@ -26,7 +32,8 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
 from ra_aid.console.formatting import print_error, print_stage_header
 from ra_aid.console.output import print_agent_output
-from ra_aid.exceptions import AgentInterrupt
+from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
+from ra_aid.fallback_handler import FallbackHandler
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
 from ra_aid.project_info import (
@@ -238,7 +245,7 @@ def create_agent(
     *,
     checkpointer: Any = None,
     agent_type: str = "default",
-) -> Any:
+):
     """Create a react agent with the given configuration.

     Args:
@@ -775,61 +782,98 @@ def check_interrupt():
         raise AgentInterrupt("Interrupt requested")


-def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
-    """Run an agent with retry logic for API errors."""
-    logger.debug("Running agent with prompt length: %d", len(prompt))
-    original_handler = None
+# New helper functions for run_agent_with_retry refactoring
+def _setup_interrupt_handling():
+    if threading.current_thread() is threading.main_thread():
+        original_handler = signal.getsignal(signal.SIGINT)
+        signal.signal(signal.SIGINT, _request_interrupt)
+        return original_handler
+    return None
+
+
+def _restore_interrupt_handling(original_handler):
+    if original_handler and threading.current_thread() is threading.main_thread():
+        signal.signal(signal.SIGINT, original_handler)
+
+
+def _increment_agent_depth():
+    current_depth = _global_memory.get("agent_depth", 0)
+    _global_memory["agent_depth"] = current_depth + 1
+
+
+def _decrement_agent_depth():
+    _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
+
+
+def _run_agent_stream(agent, prompt, config):
+    for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
+        logger.debug("Agent output: %s", chunk)
+        check_interrupt()
+        print_agent_output(chunk)
+        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
+            _global_memory["plan_completed"] = False
+            _global_memory["task_completed"] = False
+            _global_memory["completion_message"] = ""
+            break
+
+
+def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
+    return execute_test_command(config, original_prompt, test_attempts, auto_test)
+
+
+def _handle_api_error(e, attempt, max_retries, base_delay):
+    if isinstance(e, ValueError):
+        error_str = str(e).lower()
+        if "code" not in error_str or "429" not in error_str:
+            raise e
+    if attempt == max_retries - 1:
+        logger.error("Max retries reached, failing: %s", str(e))
+        raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
+    logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
+    delay = base_delay * (2**attempt)
+    print_error(
+        f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
+    )
+    start = time.monotonic()
+    while time.monotonic() - start < delay:
+        check_interrupt()
+        time.sleep(0.1)
+
+
+def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
+    """Run an agent with retry logic for API errors."""
+    logger.debug("Running agent with prompt length: %d", len(prompt))
+    original_handler = _setup_interrupt_handling()
     max_retries = 20
     base_delay = 1
     test_attempts = 0
     _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
     auto_test = config.get("auto_test", False)
     original_prompt = prompt
+    fallback_handler = FallbackHandler(config)

     with InterruptibleSection():
         try:
-            # Track agent execution depth
-            current_depth = _global_memory.get("agent_depth", 0)
-            _global_memory["agent_depth"] = current_depth + 1
-
+            _increment_agent_depth()
             for attempt in range(max_retries):
                 logger.debug("Attempt %d/%d", attempt + 1, max_retries)
                 check_interrupt()
                 try:
-                    for chunk in agent.stream(
-                        {"messages": [HumanMessage(content=prompt)]}, config
-                    ):
-                        logger.debug("Agent output: %s", chunk)
-                        check_interrupt()
-                        print_agent_output(chunk)
-
-                        if _global_memory["plan_completed"]:
-                            _global_memory["plan_completed"] = False
-                            _global_memory["task_completed"] = False
-                            _global_memory["completion_message"] = ""
-                            break
-                        if _global_memory["task_completed"]:
-                            _global_memory["task_completed"] = False
-                            _global_memory["completion_message"] = ""
-                            break
-
-                    # Execute test command if configured
+                    _run_agent_stream(agent, prompt, config)
+                    fallback_handler.reset_fallback_handler()
                     should_break, prompt, auto_test, test_attempts = (
-                        execute_test_command(
-                            config, original_prompt, test_attempts, auto_test
+                        _execute_test_command_wrapper(
+                            original_prompt, config, test_attempts, auto_test
                         )
                     )
                     if should_break:
                         break
                     if prompt != original_prompt:
                         continue

                     logger.debug("Agent run completed successfully")
                     return "Agent run completed successfully"
+                except (ToolExecutionError, InvalidToolCall) as e:
+                    _handle_tool_execution_error(fallback_handler, agent, e)
                 except (KeyboardInterrupt, AgentInterrupt):
                     raise
                 except (
@@ -839,35 +883,15 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
                     APIError,
                     ValueError,
                 ) as e:
-                    if isinstance(e, ValueError):
-                        error_str = str(e).lower()
-                        if "code" not in error_str or "429" not in error_str:
-                            raise  # Re-raise ValueError if it's not a Lambda 429
-                    if attempt == max_retries - 1:
-                        logger.error("Max retries reached, failing: %s", str(e))
-                        raise RuntimeError(
-                            f"Max retries ({max_retries}) exceeded. Last error: {e}"
-                        )
-                    logger.warning(
-                        "API error (attempt %d/%d): %s",
-                        attempt + 1,
-                        max_retries,
-                        str(e),
-                    )
-                    delay = base_delay * (2**attempt)
-                    print_error(
-                        f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
-                    )
-                    start = time.monotonic()
-                    while time.monotonic() - start < delay:
-                        check_interrupt()
-                        time.sleep(0.1)
+                    _handle_api_error(e, attempt, max_retries, base_delay)
         finally:
-            # Reset depth tracking
-            _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
-
-            if (
-                original_handler
-                and threading.current_thread() is threading.main_thread()
-            ):
-                signal.signal(signal.SIGINT, original_handler)
+            _decrement_agent_depth()
+            _restore_interrupt_handling(original_handler)
+
+
+def _handle_tool_execution_error(
+    fallback_handler: FallbackHandler,
+    agent: CiaynAgent | CompiledGraph,
+    error: ToolExecutionError | InvalidToolCall,
+):
+    fallback_handler.handle_failure("Tool execution error", error, agent)
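The extracted _handle_api_error keeps the original exponential backoff (delay = base_delay * 2**attempt) and polls check_interrupt() in 0.1 s slices while waiting, so Ctrl-C stays responsive during the delay. A minimal standalone sketch of that schedule; backoff_delays is a hypothetical illustration, not a function in this codebase:

def backoff_delays(base_delay: float, max_retries: int) -> list:
    """Delay (seconds) waited before each retry attempt: base_delay * 2**attempt."""
    return [base_delay * (2**attempt) for attempt in range(max_retries)]

# With base_delay=1 as in run_agent_with_retry, the first five waits
# are 1, 2, 4, 8 and 16 seconds.
print(backoff_delays(1, 5))  # [1, 2, 4, 8, 16]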

ra_aid/agents/ciayn_agent.py
@@ -4,7 +4,6 @@ from typing import Any, Dict, Generator, List, Optional, Union
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage

-from ra_aid.fallback_handler import FallbackHandler
 from ra_aid.exceptions import ToolExecutionError
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
@@ -84,13 +83,12 @@ class CiaynAgent:
             tools: List of tools available to the agent
             max_history_messages: Maximum number of messages to keep in chat history
             max_tokens: Maximum number of tokens allowed in message history (None for no limit)
-            config: Optional configuration dictionary for fallback settings
+            config: Optional configuration dictionary
         """
         if config is None:
            config = {}
         self.config = config
         self.provider = config.get("provider", "openai")
-        self.fallback_handler = FallbackHandler(config)

         self.model = model
         self.tools = tools
@@ -232,39 +230,29 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
         return base_prompt

     def _execute_tool(self, code: str) -> str:
-        """Execute a tool call with retry and fallback logic and return its result."""
-        max_retries = 3
-        retries = 0
-        last_error = None
-        while retries < max_retries:
-            try:
-                logger.debug(
-                    f"_execute_tool: attempt {retries+1}, original code: {code}"
-                )
-                code = code.strip()
-                logger.debug(f"_execute_tool: stripped code: {code}")
-
-                # if the eval fails, try to extract it via a model call
-                if validate_function_call_pattern(code):
-                    functions_list = "\n\n".join(self.available_functions)
-                    logger.debug(f"_execute_tool: code before extraction: {code}")
-                    code = _extract_tool_call(code, functions_list)
-                    logger.debug(f"_execute_tool: code after extraction: {code}")
-                globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
-                logger.debug(f"_execute_tool: evaluating code: {code}")
-                result = eval(code, globals_dict)
-                logger.debug(
-                    f"_execute_tool: tool executed successfully with result: {result}"
-                )
-                self.fallback_handler.reset_fallback_handler()
-                return result
-            except Exception as e:
-                logger.debug(f"_execute_tool: exception caught: {e}")
-                self._handle_tool_failure(code, e)
-                last_error = e
-                retries += 1
-                logger.debug(f"_execute_tool: retrying, new attempt count: {retries}")
-        raise ToolExecutionError(
-            f"Error executing code after {max_retries} attempts: {str(last_error)}"
-        )
-
-    def _handle_tool_failure(self, code: str, error: Exception) -> None:
-        self.fallback_handler.handle_failure(code, error, logger, self)
+        """Execute a tool call and return its result."""
+        globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
+
+        try:
+            code = code.strip()
+            # if the eval fails, try to extract it via a model call
+            if validate_function_call_pattern(code):
+                functions_list = "\n\n".join(self.available_functions)
+                code = _extract_tool_call(code, functions_list)
+            logger.debug(
+                f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}"
+            )
+            result = eval(code.strip(), globals_dict)
+            logger.debug(f"_execute_tool: result: {result}")
+            return result
+        except Exception as e:
+            error_msg = f"Error executing code: {str(e)}"
+            raise ToolExecutionError(error_msg)

     def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
         """Create an agent chunk in the format expected by print_agent_output."""
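The simplified _execute_tool dispatches by evaluating the model-emitted call string against a globals dict that maps each tool's function name to the callable itself. A standalone sketch of that dispatch, using plain functions where the real code reaches through LangChain tools via tool.func; greet is purely illustrative:

def greet(name: str) -> str:
    return f"hello {name}"

tools = [greet]
globals_dict = {fn.__name__: fn for fn in tools}  # tool name -> callable
code = ' greet("world") '
result = eval(code.strip(), globals_dict)  # evaluate the emitted call
print(result)  # hello world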

ra_aid/fallback_handler.py
@@ -9,6 +9,8 @@ from rich.markdown import Markdown
 from rich.panel import Panel
 from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
+
+logger = get_logger(__name__)


 class FallbackHandler:
     """
@@ -73,6 +75,7 @@ class FallbackHandler:
         for item in supported:
             if "type" not in item:
                 item["type"] = "prompt"
+            item["model"] = item["model"].lower()
             final_models.append(item)
         message = "Fallback models selected: " + ", ".join(
             [m["model"] for m in final_models]
@@ -85,7 +88,7 @@ class FallbackHandler:
         console.print(Panel(Markdown(message), title="Fallback Models"))
         return final_models

-    def handle_failure(self, code: str, error: Exception, logger, agent):
+    def handle_failure(self, code: str, error: Exception, agent):
         """
         Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
@@ -173,8 +176,23 @@ class FallbackHandler:
         simple_model = initialize_llm(
             fallback_model["provider"], fallback_model["model"]
         )
+        tool_to_bind = next(
+            (
+                t
+                for t in agent.tools
+                if t.func.__name__ == failed_tool_call_name
+            ),
+            None,
+        )
+        if tool_to_bind is None:
+            logger.debug(
+                f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
+            )
+            raise Exception(
+                f"Tool {failed_tool_call_name} not found in agent.tools"
+            )
         binded_model = simple_model.bind_tools(
-            agent.tools, tool_choice=failed_tool_call_name
+            [tool_to_bind], tool_choice=failed_tool_call_name
         )
         retry_model = binded_model.with_retry(
             stop_after_attempt=RETRY_FALLBACK_COUNT
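The fallback path now binds only the tool that failed: it locates the tool by function name with next(..., None), errors out explicitly if nothing matches, and passes a single-element list to bind_tools with a forced tool_choice so the retried call must use exactly that tool. A standalone sketch of the lookup under a stub tool type; FakeTool, find_tool, and double are illustrative, not codebase names:

from dataclasses import dataclass
from typing import Callable, List, Optional

@dataclass
class FakeTool:
    func: Callable  # mirrors the .func attribute the real lookup relies on

def find_tool(tools: List[FakeTool], name: str) -> Optional[FakeTool]:
    return next((t for t in tools if t.func.__name__ == name), None)

def double(x: int) -> int:
    return x * 2

tools = [FakeTool(double)]
assert find_tool(tools, "double") is tools[0]
assert find_tool(tools, "missing") is None  # would trigger the explicit error path above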
@@ -221,8 +239,23 @@ class FallbackHandler:
         simple_model = initialize_llm(
             fallback_model["provider"], fallback_model["model"]
         )
+        tool_to_bind = next(
+            (
+                t
+                for t in agent.tools
+                if t.func.__name__ == failed_tool_call_name
+            ),
+            None,
+        )
+        if tool_to_bind is None:
+            logger.debug(
+                f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
+            )
+            raise Exception(
+                f"Tool {failed_tool_call_name} not found in agent.tools"
+            )
         binded_model = simple_model.bind_tools(
-            agent.tools, tool_choice=failed_tool_call_name
+            [tool_to_bind], tool_choice=failed_tool_call_name
         )
         retry_model = binded_model.with_retry(
             stop_after_attempt=RETRY_FALLBACK_COUNT

tests
@@ -275,3 +275,81 @@ def test_get_model_token_limit_planner(mock_memory):
     mock_get_info.return_value = {"max_input_tokens": 120000}
     token_limit = get_model_token_limit(config, "planner")
     assert token_limit == 120000
+
+
+# New tests for private helper functions in agent_utils.py
+
+
+def test_setup_and_restore_interrupt_handling():
+    import signal, threading
+    from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt
+
+    original_handler = signal.getsignal(signal.SIGINT)
+    handler = _setup_interrupt_handling()
+    # Verify the SIGINT handler is set to _request_interrupt
+    assert signal.getsignal(signal.SIGINT) == _request_interrupt
+    _restore_interrupt_handling(handler)
+    # Verify the SIGINT handler is restored to the original
+    assert signal.getsignal(signal.SIGINT) == original_handler
+
+
+def test_increment_and_decrement_agent_depth():
+    from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory
+
+    _global_memory["agent_depth"] = 10
+    _increment_agent_depth()
+    assert _global_memory["agent_depth"] == 11
+    _decrement_agent_depth()
+    assert _global_memory["agent_depth"] == 10
+
+
+def test_run_agent_stream(monkeypatch):
+    from ra_aid.agent_utils import _run_agent_stream, _global_memory
+
+    # Create a dummy agent that yields one chunk
+    class DummyAgent:
+        def stream(self, msg, cfg):
+            yield {"content": "chunk1"}
+
+    dummy_agent = DummyAgent()
+    # Set flags so that _run_agent_stream will reset them
+    _global_memory["plan_completed"] = True
+    _global_memory["task_completed"] = True
+    _global_memory["completion_message"] = "existing"
+    call_flag = {"called": False}
+
+    def fake_print_agent_output(chunk):
+        call_flag["called"] = True
+
+    monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output)
+    _run_agent_stream(dummy_agent, "dummy prompt", {})
+    assert call_flag["called"]
+    assert _global_memory["plan_completed"] is False
+    assert _global_memory["task_completed"] is False
+    assert _global_memory["completion_message"] == ""
+
+
+def test_execute_test_command_wrapper(monkeypatch):
+    from ra_aid.agent_utils import _execute_test_command_wrapper
+
+    # Patch execute_test_command to return a testable tuple
+    def fake_execute(config, orig, tests, auto):
+        return (True, "new prompt", auto, tests + 1)
+
+    monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
+    result = _execute_test_command_wrapper("orig", {}, 0, False)
+    assert result == (True, "new prompt", False, 1)
+
+
+def test_handle_api_error_valueerror():
+    import pytest
+    from ra_aid.agent_utils import _handle_api_error
+
+    # A ValueError whose message lacks "code" or "429" should be re-raised
+    with pytest.raises(ValueError):
+        _handle_api_error(ValueError("some error"), 0, 5, 1)
+
+
+def test_handle_api_error_max_retries():
+    import pytest
+    from ra_aid.agent_utils import _handle_api_error
+
+    # When attempt reaches max retries, a RuntimeError should be raised
+    with pytest.raises(RuntimeError):
+        _handle_api_error(Exception("error code 429"), 4, 5, 1)
+
+
+def test_handle_api_error_retry(monkeypatch):
+    import time
+    from ra_aid.agent_utils import _handle_api_error
+
+    # Patch time.monotonic and time.sleep to simulate immediate delay expiration
+    fake_time = [0]
+
+    def fake_monotonic():
+        fake_time[0] += 0.5
+        return fake_time[0]
+
+    monkeypatch.setattr(time, "monotonic", fake_monotonic)
+    monkeypatch.setattr(time, "sleep", lambda s: None)
+    # Should not raise when attempt is lower than max retries
+    _handle_api_error(Exception("error code 429"), 0, 5, 1)
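For readers outside the test suite, here is a self-contained sketch of the install/restore pairing that test_setup_and_restore_interrupt_handling exercises, using only the standard library; request_interrupt and run_work are illustrative stand-ins for ra_aid's _request_interrupt and agent loop:

import signal
import threading

def request_interrupt(signum, frame):  # stand-in for ra_aid's _request_interrupt
    raise KeyboardInterrupt

def run_work():
    original = None
    if threading.current_thread() is threading.main_thread():
        original = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, request_interrupt)
    try:
        return "work done"  # agent streaming would happen here
    finally:
        # Always restore on the main thread so Ctrl-C behaves normally afterwards
        if original and threading.current_thread() is threading.main_thread():
            signal.signal(signal.SIGINT, original)

print(run_work())  # work done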