From 13880677694ea3d0133a7bc300879fdef08a32c1 Mon Sep 17 00:00:00 2001
From: Ariel Frischer
Date: Tue, 11 Feb 2025 12:16:04 -0800
Subject: [PATCH] refactor(agent_utils.py): extract helper functions from
 run_agent_with_retry for better readability and maintainability

feat(agent_utils.py): add new helper functions for handling API errors and
managing interrupt signals
fix(agent_utils.py): improve error handling in tool execution and retry logic
feat(fallback_handler.py): enhance fallback handling by binding the failed
tool correctly during retries
test(tests): add unit tests for new helper functions and refactored logic in
agent_utils.py
---
 ra_aid/agent_utils.py            | 148 ++++++++++++++++++-------
 ra_aid/agents/ciayn_agent.py     |  58 +++++-------
 ra_aid/fallback_handler.py       |  39 +++++++-
 tests/ra_aid/test_agent_utils.py |  78 ++++++++++++++++
 4 files changed, 223 insertions(+), 100 deletions(-)

diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
index 4dcc6a5..678ce73 100644
--- a/ra_aid/agent_utils.py
+++ b/ra_aid/agent_utils.py
@@ -9,10 +9,16 @@ import uuid
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Sequence
 
+from langgraph.graph.graph import CompiledGraph
 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
 from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import BaseMessage, HumanMessage, trim_messages
+from langchain_core.messages import (
+    BaseMessage,
+    HumanMessage,
+    InvalidToolCall,
+    trim_messages,
+)
 from langchain_core.tools import tool
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
@@ -26,7 +32,8 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
 from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
 from ra_aid.console.formatting import print_error, print_stage_header
 from ra_aid.console.output import print_agent_output
-from ra_aid.exceptions import AgentInterrupt
+from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
+from ra_aid.fallback_handler import FallbackHandler
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
 from ra_aid.project_info import (
@@ -238,7 +245,7 @@ def create_agent(
     *,
     checkpointer: Any = None,
     agent_type: str = "default",
-) -> Any:
+):
     """Create a react agent with the given configuration.
 
     Args:
@@ -775,61 +782,98 @@ def check_interrupt():
         raise AgentInterrupt("Interrupt requested")
 
 
-def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
-    """Run an agent with retry logic for API errors."""
-    logger.debug("Running agent with prompt length: %d", len(prompt))
-    original_handler = None
+# New helper functions for run_agent_with_retry refactoring
+def _setup_interrupt_handling():
     if threading.current_thread() is threading.main_thread():
         original_handler = signal.getsignal(signal.SIGINT)
         signal.signal(signal.SIGINT, _request_interrupt)
+        return original_handler
+    return None
+
+
+def _restore_interrupt_handling(original_handler):
+    if original_handler and threading.current_thread() is threading.main_thread():
+        signal.signal(signal.SIGINT, original_handler)
+
+
+def _increment_agent_depth():
+    current_depth = _global_memory.get("agent_depth", 0)
+    _global_memory["agent_depth"] = current_depth + 1
+
+
+def _decrement_agent_depth():
+    _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
+
+
+def _run_agent_stream(agent, prompt, config):
+    for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
+        logger.debug("Agent output: %s", chunk)
+        check_interrupt()
+        print_agent_output(chunk)
+        if _global_memory["plan_completed"] or _global_memory["task_completed"]:
+            _global_memory["plan_completed"] = False
+            _global_memory["task_completed"] = False
+            _global_memory["completion_message"] = ""
+            break
+
+
+def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
+    return execute_test_command(config, original_prompt, test_attempts, auto_test)
+
+
+def _handle_api_error(e, attempt, max_retries, base_delay):
+    if isinstance(e, ValueError):
+        error_str = str(e).lower()
+        if "code" not in error_str or "429" not in error_str:
+            raise e  # re-raise ValueError if it's not a rate-limit (429) error
+    if attempt == max_retries - 1:
+        logger.error("Max retries reached, failing: %s", str(e))
+        raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
+    logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
+    delay = base_delay * (2**attempt)
+    print_error(
+        f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
+    )
+    start = time.monotonic()
+    while time.monotonic() - start < delay:
+        check_interrupt()
+        time.sleep(0.1)
+
+
+def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
+    """Run an agent with retry logic for API errors."""
+    logger.debug("Running agent with prompt length: %d", len(prompt))
+    original_handler = _setup_interrupt_handling()
     max_retries = 20
     base_delay = 1
     test_attempts = 0
     _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
     auto_test = config.get("auto_test", False)
     original_prompt = prompt
+    fallback_handler = FallbackHandler(config)
 
     with InterruptibleSection():
         try:
-            # Track agent execution depth
-            current_depth = _global_memory.get("agent_depth", 0)
-            _global_memory["agent_depth"] = current_depth + 1
-
+            _increment_agent_depth()
             for attempt in range(max_retries):
                 logger.debug("Attempt %d/%d", attempt + 1, max_retries)
                 check_interrupt()
                 try:
-                    for chunk in agent.stream(
-                        {"messages": [HumanMessage(content=prompt)]}, config
-                    ):
-                        logger.debug("Agent output: %s", chunk)
-                        check_interrupt()
-                        print_agent_output(chunk)
-
-                        if _global_memory["plan_completed"]:
-                            _global_memory["plan_completed"] = False
-                            _global_memory["task_completed"] = False
-                            _global_memory["completion_message"] = ""
-                            break
-                        if _global_memory["task_completed"]:
-                            _global_memory["task_completed"] = False
-                            _global_memory["completion_message"] = ""
-                            break
-
-                    # Execute test command if configured
+                    _run_agent_stream(agent, prompt, config)
+                    fallback_handler.reset_fallback_handler()
                     should_break, prompt, auto_test, test_attempts = (
-                        execute_test_command(
-                            config, original_prompt, test_attempts, auto_test
+                        _execute_test_command_wrapper(
+                            original_prompt, config, test_attempts, auto_test
                         )
                     )
                     if should_break:
                         break
                     if prompt != original_prompt:
                         continue
-
                     logger.debug("Agent run completed successfully")
                     return "Agent run completed successfully"
+                except (ToolExecutionError, InvalidToolCall) as e:
+                    _handle_tool_execution_error(fallback_handler, agent, e)
                 except (KeyboardInterrupt, AgentInterrupt):
                     raise
                 except (
@@ -839,35 +883,15 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
                     APIError,
                     ValueError,
                 ) as e:
-                    if isinstance(e, ValueError):
-                        error_str = str(e).lower()
-                        if "code" not in error_str or "429" not in error_str:
-                            raise  # Re-raise ValueError if it's not a Lambda 429
-                    if attempt == max_retries - 1:
-                        logger.error("Max retries reached, failing: %s", str(e))
-                        raise RuntimeError(
-                            f"Max retries ({max_retries}) exceeded. Last error: {e}"
-                        )
-                    logger.warning(
-                        "API error (attempt %d/%d): %s",
-                        attempt + 1,
-                        max_retries,
-                        str(e),
-                    )
-                    delay = base_delay * (2**attempt)
-                    print_error(
-                        f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
-                    )
-                    start = time.monotonic()
-                    while time.monotonic() - start < delay:
-                        check_interrupt()
-                        time.sleep(0.1)
+                    _handle_api_error(e, attempt, max_retries, base_delay)
         finally:
-            # Reset depth tracking
-            _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
+            _decrement_agent_depth()
+            _restore_interrupt_handling(original_handler)
 
-            if (
-                original_handler
-                and threading.current_thread() is threading.main_thread()
-            ):
-                signal.signal(signal.SIGINT, original_handler)
+
+def _handle_tool_execution_error(
+    fallback_handler: FallbackHandler,
+    agent: CiaynAgent | CompiledGraph,
+    error: ToolExecutionError | InvalidToolCall,
+):
+    fallback_handler.handle_failure("Tool execution error", error, agent)
diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py
index 57c7467..4060684 100644
--- a/ra_aid/agents/ciayn_agent.py
+++ b/ra_aid/agents/ciayn_agent.py
@@ -4,7 +4,6 @@ from typing import Any, Dict, Generator, List, Optional, Union
 
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
 
-from ra_aid.fallback_handler import FallbackHandler
 from ra_aid.exceptions import ToolExecutionError
 from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
@@ -84,13 +83,12 @@ class CiaynAgent:
             tools: List of tools available to the agent
             max_history_messages: Maximum number of messages to keep in chat history
             max_tokens: Maximum number of tokens allowed in message history (None for no limit)
-            config: Optional configuration dictionary for fallback settings
+            config: Optional configuration dictionary
         """
         if config is None:
             config = {}
         self.config = config
         self.provider = config.get("provider", "openai")
-        self.fallback_handler = FallbackHandler(config)
 
         self.model = model
         self.tools = tools
@@ -232,39 +230,29 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
         return base_prompt
 
     def _execute_tool(self, code: str) -> str:
-        """Execute a tool call with retry and fallback logic and return its result."""
-        max_retries = 3
-        retries = 0
-        last_error = None
-        while retries < max_retries:
-            try:
-                logger.debug(
-                    f"_execute_tool: attempt {retries+1}, original code: {code}"
-                )
-                code = code.strip()
-                if validate_function_call_pattern(code):
-                    functions_list = "\n\n".join(self.available_functions)
-                    code = _extract_tool_call(code, functions_list)
-                globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
-                logger.debug(f"_execute_tool: evaluating code: {code}")
-                result = eval(code, globals_dict)
-                logger.debug(
-                    f"_execute_tool: tool executed successfully with result: {result}"
-                )
-                self.fallback_handler.reset_fallback_handler()
-                return result
-            except Exception as e:
-                logger.debug(f"_execute_tool: exception caught: {e}")
-                self._handle_tool_failure(code, e)
-                last_error = e
-                retries += 1
-                logger.debug(f"_execute_tool: retrying, new attempt count: {retries}")
-        raise ToolExecutionError(
-            f"Error executing code after {max_retries} attempts: {str(last_error)}"
-        )
+        """Execute a tool call and return its result."""
+        globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
 
-    def _handle_tool_failure(self, code: str, error: Exception) -> None:
-        self.fallback_handler.handle_failure(code, error, logger, self)
+        try:
+            code = code.strip()
+            logger.debug(f"_execute_tool: stripped code: {code}")
+
+            # if the code doesn't match the expected call pattern, extract it via a model call
+            if validate_function_call_pattern(code):
+                functions_list = "\n\n".join(self.available_functions)
"\n\n".join(self.available_functions) + logger.debug(f"_execute_tool: code before extraction: {code}") + code = _extract_tool_call(code, functions_list) + logger.debug(f"_execute_tool: code after extraction: {code}") + + logger.debug( + f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}" + ) + result = eval(code.strip(), globals_dict) + logger.debug(f"_execute_tool: result: {result}") + return result + except Exception as e: + error_msg = f"Error executing code: {str(e)}" + raise ToolExecutionError(error_msg) def _create_agent_chunk(self, content: str) -> Dict[str, Any]: """Create an agent chunk in the format expected by print_agent_output.""" diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index ebc3cb6..2b248a5 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -9,6 +9,8 @@ from rich.markdown import Markdown from rich.panel import Panel from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env +logger = get_logger(__name__) + class FallbackHandler: """ @@ -73,6 +75,7 @@ class FallbackHandler: for item in supported: if "type" not in item: item["type"] = "prompt" + item["model"] = item["model"].lower() final_models.append(item) message = "Fallback models selected: " + ", ".join( [m["model"] for m in final_models] @@ -85,7 +88,7 @@ class FallbackHandler: console.print(Panel(Markdown(message), title="Fallback Models")) return final_models - def handle_failure(self, code: str, error: Exception, logger, agent): + def handle_failure(self, code: str, error: Exception, agent): """ Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded. @@ -173,8 +176,23 @@ class FallbackHandler: simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) + tool_to_bind = next( + ( + t + for t in agent.tools + if t.func.__name__ == failed_tool_call_name + ), + None, + ) + if tool_to_bind is None: + logger.debug( + f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}" + ) + raise Exception( + f"Tool {failed_tool_call_name} not found in agent.tools" + ) binded_model = simple_model.bind_tools( - agent.tools, tool_choice=failed_tool_call_name + [tool_to_bind], tool_choice=failed_tool_call_name ) retry_model = binded_model.with_retry( stop_after_attempt=RETRY_FALLBACK_COUNT @@ -221,8 +239,23 @@ class FallbackHandler: simple_model = initialize_llm( fallback_model["provider"], fallback_model["model"] ) + tool_to_bind = next( + ( + t + for t in agent.tools + if t.func.__name__ == failed_tool_call_name + ), + None, + ) + if tool_to_bind is None: + logger.debug( + f"Failed to find tool: {failed_tool_call_name}. 
+                    )
+                    raise Exception(
+                        f"Tool {failed_tool_call_name} not found in agent.tools"
+                    )
                 binded_model = simple_model.bind_tools(
-                    agent.tools, tool_choice=failed_tool_call_name
+                    [tool_to_bind], tool_choice=failed_tool_call_name
                 )
                 retry_model = binded_model.with_retry(
                     stop_after_attempt=RETRY_FALLBACK_COUNT
diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py
index 5e935ed..7a02a85 100644
--- a/tests/ra_aid/test_agent_utils.py
+++ b/tests/ra_aid/test_agent_utils.py
@@ -275,3 +275,81 @@ def test_get_model_token_limit_planner(mock_memory):
         mock_get_info.return_value = {"max_input_tokens": 120000}
         token_limit = get_model_token_limit(config, "planner")
         assert token_limit == 120000
+
+# New tests for the private helper functions in agent_utils.py
+
+def test_setup_and_restore_interrupt_handling():
+    import signal
+    from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt
+    original_handler = signal.getsignal(signal.SIGINT)
+    handler = _setup_interrupt_handling()
+    # Verify the SIGINT handler is set to _request_interrupt
+    assert signal.getsignal(signal.SIGINT) == _request_interrupt
+    _restore_interrupt_handling(handler)
+    # Verify the SIGINT handler is restored to the original
+    assert signal.getsignal(signal.SIGINT) == original_handler
+
+def test_increment_and_decrement_agent_depth():
+    from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory
+    _global_memory["agent_depth"] = 10
+    _increment_agent_depth()
+    assert _global_memory["agent_depth"] == 11
+    _decrement_agent_depth()
+    assert _global_memory["agent_depth"] == 10
+
+def test_run_agent_stream(monkeypatch):
+    from ra_aid.agent_utils import _run_agent_stream, _global_memory
+    # Create a dummy agent that yields one chunk
+    class DummyAgent:
+        def stream(self, msg, cfg):
+            yield {"content": "chunk1"}
+    dummy_agent = DummyAgent()
+    # Set flags so that _run_agent_stream will reset them
+    _global_memory["plan_completed"] = True
+    _global_memory["task_completed"] = True
+    _global_memory["completion_message"] = "existing"
+    call_flag = {"called": False}
+    def fake_print_agent_output(chunk):
+        call_flag["called"] = True
+    monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output)
+    _run_agent_stream(dummy_agent, "dummy prompt", {})
+    assert call_flag["called"]
+    assert _global_memory["plan_completed"] is False
+    assert _global_memory["task_completed"] is False
+    assert _global_memory["completion_message"] == ""
+
+def test_execute_test_command_wrapper(monkeypatch):
+    from ra_aid.agent_utils import _execute_test_command_wrapper
+    # Patch execute_test_command to return a testable tuple
+    def fake_execute(config, orig, tests, auto):
+        return (True, "new prompt", auto, tests + 1)
+    monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
+    result = _execute_test_command_wrapper("orig", {}, 0, False)
+    assert result == (True, "new prompt", False, 1)
+
+def test_handle_api_error_valueerror():
+    from ra_aid.agent_utils import _handle_api_error
+    import pytest
+    # A ValueError that is not a 429 rate-limit error should be re-raised
+    with pytest.raises(ValueError):
+        _handle_api_error(ValueError("some error"), 0, 5, 1)
+
+def test_handle_api_error_max_retries():
+    from ra_aid.agent_utils import _handle_api_error
+    import pytest
+    # When attempt reaches max retries, a RuntimeError should be raised
+    with pytest.raises(RuntimeError):
+        _handle_api_error(Exception("error code 429"), 4, 5, 1)
+
+def test_handle_api_error_retry(monkeypatch):
+    from ra_aid.agent_utils import _handle_api_error
+    import time
+    # Patch time.monotonic and time.sleep to simulate immediate delay expiration
+    fake_time = [0]
+    def fake_monotonic():
+        fake_time[0] += 0.5
+        return fake_time[0]
+    monkeypatch.setattr(time, "monotonic", fake_monotonic)
+    monkeypatch.setattr(time, "sleep", lambda s: None)
+    # Should not raise an error when the attempt count is below max retries
+    _handle_api_error(Exception("error code 429"), 0, 5, 1)
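
For reviewers, a minimal usage sketch of the retry contract implemented by the
new _handle_api_error helper. The sketch is illustrative and not part of the
patch; the simulated ValueError and the surrounding loop are stand-ins for a
real agent run:

    from ra_aid.agent_utils import _handle_api_error

    max_retries = 5
    base_delay = 1

    for attempt in range(max_retries):
        try:
            # Stand-in for a streaming agent call that hits a rate limit.
            raise ValueError("error code 429")
        except ValueError as e:
            # Waits base_delay * 2**attempt seconds in 0.1s slices (calling
            # check_interrupt() between slices), then returns so the caller
            # retries; on the final attempt it raises RuntimeError instead,
            # and any ValueError that is not a 429 rate-limit error is
            # re-raised immediately.
            _handle_api_error(e, attempt, max_retries, base_delay)

The delay doubles each attempt (base_delay * 2**attempt), so with base_delay
of 1 the waits are 1s, 2s, 4s, and 8s before the fifth attempt fails.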
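The interrupt-handling pair follows the same save/restore pattern that
run_agent_with_retry uses; a sketch, with the try body as a placeholder (off
the main thread both helpers are no-ops):

    from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling

    # Install _request_interrupt on SIGINT and remember the previous handler.
    original_handler = _setup_interrupt_handling()
    try:
        ...  # run the agent here
    finally:
        # Put the previous SIGINT handler back (main thread only).
        _restore_interrupt_handling(original_handler)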
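A note on the fallback change: binding only the failed tool, as in
bind_tools([tool_to_bind], tool_choice=failed_tool_call_name), instead of the
whole agent.tools list constrains the fallback model to emit a call to exactly
the tool that just failed, rather than letting it pick any available tool
during the retry.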