refactor(agent_utils.py): refactor run_agent_with_retry function for better readability and maintainability by extracting helper functions
feat(agent_utils.py): add new helper functions for handling API errors and managing interrupt signals fix(agent_utils.py): improve error handling in tool execution and retry logic feat(fallback_handler.py): enhance fallback handling by binding tools correctly during retries test(tests): add unit tests for new helper functions and refactored logic in agent_utils.py
This commit is contained in:
parent
de489584e5
commit
1388067769
|
|
@ -9,10 +9,16 @@ import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||||
|
|
||||||
|
from langgraph.graph.graph import CompiledGraph
|
||||||
import litellm
|
import litellm
|
||||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import BaseMessage, HumanMessage, trim_messages
|
from langchain_core.messages import (
|
||||||
|
BaseMessage,
|
||||||
|
HumanMessage,
|
||||||
|
InvalidToolCall,
|
||||||
|
trim_messages,
|
||||||
|
)
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
|
@ -26,7 +32,8 @@ from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
||||||
from ra_aid.console.formatting import print_error, print_stage_header
|
from ra_aid.console.formatting import print_error, print_stage_header
|
||||||
from ra_aid.console.output import print_agent_output
|
from ra_aid.console.output import print_agent_output
|
||||||
from ra_aid.exceptions import AgentInterrupt
|
from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
|
||||||
|
from ra_aid.fallback_handler import FallbackHandler
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||||
from ra_aid.project_info import (
|
from ra_aid.project_info import (
|
||||||
|
|
@ -238,7 +245,7 @@ def create_agent(
|
||||||
*,
|
*,
|
||||||
checkpointer: Any = None,
|
checkpointer: Any = None,
|
||||||
agent_type: str = "default",
|
agent_type: str = "default",
|
||||||
) -> Any:
|
):
|
||||||
"""Create a react agent with the given configuration.
|
"""Create a react agent with the given configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -775,61 +782,98 @@ def check_interrupt():
|
||||||
raise AgentInterrupt("Interrupt requested")
|
raise AgentInterrupt("Interrupt requested")
|
||||||
|
|
||||||
|
|
||||||
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
# New helper functions for run_agent_with_retry refactoring
|
||||||
"""Run an agent with retry logic for API errors."""
|
def _setup_interrupt_handling():
|
||||||
logger.debug("Running agent with prompt length: %d", len(prompt))
|
|
||||||
original_handler = None
|
|
||||||
if threading.current_thread() is threading.main_thread():
|
if threading.current_thread() is threading.main_thread():
|
||||||
original_handler = signal.getsignal(signal.SIGINT)
|
original_handler = signal.getsignal(signal.SIGINT)
|
||||||
signal.signal(signal.SIGINT, _request_interrupt)
|
signal.signal(signal.SIGINT, _request_interrupt)
|
||||||
|
return original_handler
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _restore_interrupt_handling(original_handler):
|
||||||
|
if original_handler and threading.current_thread() is threading.main_thread():
|
||||||
|
signal.signal(signal.SIGINT, original_handler)
|
||||||
|
|
||||||
|
|
||||||
|
def _increment_agent_depth():
|
||||||
|
current_depth = _global_memory.get("agent_depth", 0)
|
||||||
|
_global_memory["agent_depth"] = current_depth + 1
|
||||||
|
|
||||||
|
|
||||||
|
def _decrement_agent_depth():
|
||||||
|
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
||||||
|
|
||||||
|
|
||||||
|
def _run_agent_stream(agent, prompt, config):
|
||||||
|
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
|
||||||
|
logger.debug("Agent output: %s", chunk)
|
||||||
|
check_interrupt()
|
||||||
|
print_agent_output(chunk)
|
||||||
|
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
||||||
|
_global_memory["plan_completed"] = False
|
||||||
|
_global_memory["task_completed"] = False
|
||||||
|
_global_memory["completion_message"] = ""
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
|
||||||
|
return execute_test_command(config, original_prompt, test_attempts, auto_test)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_api_error(e, attempt, max_retries, base_delay):
|
||||||
|
if isinstance(e, ValueError):
|
||||||
|
error_str = str(e).lower()
|
||||||
|
if "code" not in error_str or "429" not in error_str:
|
||||||
|
raise e
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
logger.error("Max retries reached, failing: %s", str(e))
|
||||||
|
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
|
||||||
|
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
|
||||||
|
delay = base_delay * (2**attempt)
|
||||||
|
print_error(
|
||||||
|
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||||
|
)
|
||||||
|
start = time.monotonic()
|
||||||
|
while time.monotonic() - start < delay:
|
||||||
|
check_interrupt()
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
||||||
|
"""Run an agent with retry logic for API errors."""
|
||||||
|
logger.debug("Running agent with prompt length: %d", len(prompt))
|
||||||
|
original_handler = _setup_interrupt_handling()
|
||||||
max_retries = 20
|
max_retries = 20
|
||||||
base_delay = 1
|
base_delay = 1
|
||||||
test_attempts = 0
|
test_attempts = 0
|
||||||
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
|
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
|
||||||
auto_test = config.get("auto_test", False)
|
auto_test = config.get("auto_test", False)
|
||||||
original_prompt = prompt
|
original_prompt = prompt
|
||||||
|
fallback_handler = FallbackHandler(config)
|
||||||
|
|
||||||
with InterruptibleSection():
|
with InterruptibleSection():
|
||||||
try:
|
try:
|
||||||
# Track agent execution depth
|
_increment_agent_depth()
|
||||||
current_depth = _global_memory.get("agent_depth", 0)
|
|
||||||
_global_memory["agent_depth"] = current_depth + 1
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
try:
|
try:
|
||||||
for chunk in agent.stream(
|
_run_agent_stream(agent, prompt, config)
|
||||||
{"messages": [HumanMessage(content=prompt)]}, config
|
fallback_handler.reset_fallback_handler()
|
||||||
):
|
|
||||||
logger.debug("Agent output: %s", chunk)
|
|
||||||
check_interrupt()
|
|
||||||
print_agent_output(chunk)
|
|
||||||
|
|
||||||
if _global_memory["plan_completed"]:
|
|
||||||
_global_memory["plan_completed"] = False
|
|
||||||
_global_memory["task_completed"] = False
|
|
||||||
_global_memory["completion_message"] = ""
|
|
||||||
break
|
|
||||||
if _global_memory["task_completed"]:
|
|
||||||
_global_memory["task_completed"] = False
|
|
||||||
_global_memory["completion_message"] = ""
|
|
||||||
break
|
|
||||||
|
|
||||||
# Execute test command if configured
|
|
||||||
should_break, prompt, auto_test, test_attempts = (
|
should_break, prompt, auto_test, test_attempts = (
|
||||||
execute_test_command(
|
_execute_test_command_wrapper(
|
||||||
config, original_prompt, test_attempts, auto_test
|
original_prompt, config, test_attempts, auto_test
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if should_break:
|
if should_break:
|
||||||
break
|
break
|
||||||
if prompt != original_prompt:
|
if prompt != original_prompt:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.debug("Agent run completed successfully")
|
logger.debug("Agent run completed successfully")
|
||||||
return "Agent run completed successfully"
|
return "Agent run completed successfully"
|
||||||
|
except (ToolExecutionError, InvalidToolCall) as e:
|
||||||
|
_handle_tool_execution_error(fallback_handler, agent, e)
|
||||||
except (KeyboardInterrupt, AgentInterrupt):
|
except (KeyboardInterrupt, AgentInterrupt):
|
||||||
raise
|
raise
|
||||||
except (
|
except (
|
||||||
|
|
@ -839,35 +883,15 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
||||||
APIError,
|
APIError,
|
||||||
ValueError,
|
ValueError,
|
||||||
) as e:
|
) as e:
|
||||||
if isinstance(e, ValueError):
|
_handle_api_error(e, attempt, max_retries, base_delay)
|
||||||
error_str = str(e).lower()
|
|
||||||
if "code" not in error_str or "429" not in error_str:
|
|
||||||
raise # Re-raise ValueError if it's not a Lambda 429
|
|
||||||
if attempt == max_retries - 1:
|
|
||||||
logger.error("Max retries reached, failing: %s", str(e))
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Max retries ({max_retries}) exceeded. Last error: {e}"
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
"API error (attempt %d/%d): %s",
|
|
||||||
attempt + 1,
|
|
||||||
max_retries,
|
|
||||||
str(e),
|
|
||||||
)
|
|
||||||
delay = base_delay * (2**attempt)
|
|
||||||
print_error(
|
|
||||||
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
|
||||||
)
|
|
||||||
start = time.monotonic()
|
|
||||||
while time.monotonic() - start < delay:
|
|
||||||
check_interrupt()
|
|
||||||
time.sleep(0.1)
|
|
||||||
finally:
|
finally:
|
||||||
# Reset depth tracking
|
_decrement_agent_depth()
|
||||||
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
_restore_interrupt_handling(original_handler)
|
||||||
|
|
||||||
if (
|
|
||||||
original_handler
|
def _handle_tool_execution_error(
|
||||||
and threading.current_thread() is threading.main_thread()
|
fallback_handler: FallbackHandler,
|
||||||
):
|
agent: CiaynAgent | CompiledGraph,
|
||||||
signal.signal(signal.SIGINT, original_handler)
|
error: ToolExecutionError | InvalidToolCall,
|
||||||
|
):
|
||||||
|
fallback_handler.handle_failure("Tool execution error", error, agent)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from typing import Any, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
from ra_aid.fallback_handler import FallbackHandler
|
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||||
|
|
@ -84,13 +83,12 @@ class CiaynAgent:
|
||||||
tools: List of tools available to the agent
|
tools: List of tools available to the agent
|
||||||
max_history_messages: Maximum number of messages to keep in chat history
|
max_history_messages: Maximum number of messages to keep in chat history
|
||||||
max_tokens: Maximum number of tokens allowed in message history (None for no limit)
|
max_tokens: Maximum number of tokens allowed in message history (None for no limit)
|
||||||
config: Optional configuration dictionary for fallback settings
|
config: Optional configuration dictionary
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = {}
|
config = {}
|
||||||
self.config = config
|
self.config = config
|
||||||
self.provider = config.get("provider", "openai")
|
self.provider = config.get("provider", "openai")
|
||||||
self.fallback_handler = FallbackHandler(config)
|
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
|
@ -232,39 +230,29 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
return base_prompt
|
return base_prompt
|
||||||
|
|
||||||
def _execute_tool(self, code: str) -> str:
|
def _execute_tool(self, code: str) -> str:
|
||||||
"""Execute a tool call with retry and fallback logic and return its result."""
|
"""Execute a tool call and return its result."""
|
||||||
max_retries = 3
|
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
||||||
retries = 0
|
|
||||||
last_error = None
|
|
||||||
while retries < max_retries:
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"_execute_tool: attempt {retries+1}, original code: {code}"
|
|
||||||
)
|
|
||||||
code = code.strip()
|
|
||||||
if validate_function_call_pattern(code):
|
|
||||||
functions_list = "\n\n".join(self.available_functions)
|
|
||||||
code = _extract_tool_call(code, functions_list)
|
|
||||||
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
|
||||||
logger.debug(f"_execute_tool: evaluating code: {code}")
|
|
||||||
result = eval(code, globals_dict)
|
|
||||||
logger.debug(
|
|
||||||
f"_execute_tool: tool executed successfully with result: {result}"
|
|
||||||
)
|
|
||||||
self.fallback_handler.reset_fallback_handler()
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"_execute_tool: exception caught: {e}")
|
|
||||||
self._handle_tool_failure(code, e)
|
|
||||||
last_error = e
|
|
||||||
retries += 1
|
|
||||||
logger.debug(f"_execute_tool: retrying, new attempt count: {retries}")
|
|
||||||
raise ToolExecutionError(
|
|
||||||
f"Error executing code after {max_retries} attempts: {str(last_error)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_tool_failure(self, code: str, error: Exception) -> None:
|
try:
|
||||||
self.fallback_handler.handle_failure(code, error, logger, self)
|
code = code.strip()
|
||||||
|
logger.debug(f"_execute_tool: stripped code: {code}")
|
||||||
|
|
||||||
|
# if the eval fails, try to extract it via a model call
|
||||||
|
if validate_function_call_pattern(code):
|
||||||
|
functions_list = "\n\n".join(self.available_functions)
|
||||||
|
logger.debug(f"_execute_tool: code before extraction: {code}")
|
||||||
|
code = _extract_tool_call(code, functions_list)
|
||||||
|
logger.debug(f"_execute_tool: code after extraction: {code}")
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}"
|
||||||
|
)
|
||||||
|
result = eval(code.strip(), globals_dict)
|
||||||
|
logger.debug(f"_execute_tool: result: {result}")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error executing code: {str(e)}"
|
||||||
|
raise ToolExecutionError(error_msg)
|
||||||
|
|
||||||
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
||||||
"""Create an agent chunk in the format expected by print_agent_output."""
|
"""Create an agent chunk in the format expected by print_agent_output."""
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
|
from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FallbackHandler:
|
class FallbackHandler:
|
||||||
"""
|
"""
|
||||||
|
|
@ -73,6 +75,7 @@ class FallbackHandler:
|
||||||
for item in supported:
|
for item in supported:
|
||||||
if "type" not in item:
|
if "type" not in item:
|
||||||
item["type"] = "prompt"
|
item["type"] = "prompt"
|
||||||
|
item["model"] = item["model"].lower()
|
||||||
final_models.append(item)
|
final_models.append(item)
|
||||||
message = "Fallback models selected: " + ", ".join(
|
message = "Fallback models selected: " + ", ".join(
|
||||||
[m["model"] for m in final_models]
|
[m["model"] for m in final_models]
|
||||||
|
|
@ -85,7 +88,7 @@ class FallbackHandler:
|
||||||
console.print(Panel(Markdown(message), title="Fallback Models"))
|
console.print(Panel(Markdown(message), title="Fallback Models"))
|
||||||
return final_models
|
return final_models
|
||||||
|
|
||||||
def handle_failure(self, code: str, error: Exception, logger, agent):
|
def handle_failure(self, code: str, error: Exception, agent):
|
||||||
"""
|
"""
|
||||||
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
|
||||||
|
|
||||||
|
|
@ -173,8 +176,23 @@ class FallbackHandler:
|
||||||
simple_model = initialize_llm(
|
simple_model = initialize_llm(
|
||||||
fallback_model["provider"], fallback_model["model"]
|
fallback_model["provider"], fallback_model["model"]
|
||||||
)
|
)
|
||||||
|
tool_to_bind = next(
|
||||||
|
(
|
||||||
|
t
|
||||||
|
for t in agent.tools
|
||||||
|
if t.func.__name__ == failed_tool_call_name
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if tool_to_bind is None:
|
||||||
|
logger.debug(
|
||||||
|
f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
|
||||||
|
)
|
||||||
|
raise Exception(
|
||||||
|
f"Tool {failed_tool_call_name} not found in agent.tools"
|
||||||
|
)
|
||||||
binded_model = simple_model.bind_tools(
|
binded_model = simple_model.bind_tools(
|
||||||
agent.tools, tool_choice=failed_tool_call_name
|
[tool_to_bind], tool_choice=failed_tool_call_name
|
||||||
)
|
)
|
||||||
retry_model = binded_model.with_retry(
|
retry_model = binded_model.with_retry(
|
||||||
stop_after_attempt=RETRY_FALLBACK_COUNT
|
stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||||
|
|
@ -221,8 +239,23 @@ class FallbackHandler:
|
||||||
simple_model = initialize_llm(
|
simple_model = initialize_llm(
|
||||||
fallback_model["provider"], fallback_model["model"]
|
fallback_model["provider"], fallback_model["model"]
|
||||||
)
|
)
|
||||||
|
tool_to_bind = next(
|
||||||
|
(
|
||||||
|
t
|
||||||
|
for t in agent.tools
|
||||||
|
if t.func.__name__ == failed_tool_call_name
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if tool_to_bind is None:
|
||||||
|
logger.debug(
|
||||||
|
f"Failed to find tool: {failed_tool_call_name}. Available tools: {[t.func.__name__ for t in agent.tools]}"
|
||||||
|
)
|
||||||
|
raise Exception(
|
||||||
|
f"Tool {failed_tool_call_name} not found in agent.tools"
|
||||||
|
)
|
||||||
binded_model = simple_model.bind_tools(
|
binded_model = simple_model.bind_tools(
|
||||||
agent.tools, tool_choice=failed_tool_call_name
|
[tool_to_bind], tool_choice=failed_tool_call_name
|
||||||
)
|
)
|
||||||
retry_model = binded_model.with_retry(
|
retry_model = binded_model.with_retry(
|
||||||
stop_after_attempt=RETRY_FALLBACK_COUNT
|
stop_after_attempt=RETRY_FALLBACK_COUNT
|
||||||
|
|
|
||||||
|
|
@ -275,3 +275,81 @@ def test_get_model_token_limit_planner(mock_memory):
|
||||||
mock_get_info.return_value = {"max_input_tokens": 120000}
|
mock_get_info.return_value = {"max_input_tokens": 120000}
|
||||||
token_limit = get_model_token_limit(config, "planner")
|
token_limit = get_model_token_limit(config, "planner")
|
||||||
assert token_limit == 120000
|
assert token_limit == 120000
|
||||||
|
|
||||||
|
# New tests for private helper methods in agent_utils.py
|
||||||
|
|
||||||
|
def test_setup_and_restore_interrupt_handling():
|
||||||
|
import signal, threading
|
||||||
|
from ra_aid.agent_utils import _setup_interrupt_handling, _restore_interrupt_handling, _request_interrupt
|
||||||
|
original_handler = signal.getsignal(signal.SIGINT)
|
||||||
|
handler = _setup_interrupt_handling()
|
||||||
|
# Verify the SIGINT handler is set to _request_interrupt
|
||||||
|
assert signal.getsignal(signal.SIGINT) == _request_interrupt
|
||||||
|
_restore_interrupt_handling(handler)
|
||||||
|
# Verify the SIGINT handler is restored to the original
|
||||||
|
assert signal.getsignal(signal.SIGINT) == original_handler
|
||||||
|
|
||||||
|
def test_increment_and_decrement_agent_depth():
|
||||||
|
from ra_aid.agent_utils import _increment_agent_depth, _decrement_agent_depth, _global_memory
|
||||||
|
_global_memory["agent_depth"] = 10
|
||||||
|
_increment_agent_depth()
|
||||||
|
assert _global_memory["agent_depth"] == 11
|
||||||
|
_decrement_agent_depth()
|
||||||
|
assert _global_memory["agent_depth"] == 10
|
||||||
|
|
||||||
|
def test_run_agent_stream(monkeypatch):
|
||||||
|
from ra_aid.agent_utils import _run_agent_stream, _global_memory
|
||||||
|
# Create a dummy agent that yields one chunk
|
||||||
|
class DummyAgent:
|
||||||
|
def stream(self, msg, cfg):
|
||||||
|
yield {"content": "chunk1"}
|
||||||
|
dummy_agent = DummyAgent()
|
||||||
|
# Set flags so that _run_agent_stream will reset them
|
||||||
|
_global_memory["plan_completed"] = True
|
||||||
|
_global_memory["task_completed"] = True
|
||||||
|
_global_memory["completion_message"] = "existing"
|
||||||
|
call_flag = {"called": False}
|
||||||
|
def fake_print_agent_output(chunk):
|
||||||
|
call_flag["called"] = True
|
||||||
|
monkeypatch.setattr("ra_aid.agent_utils.print_agent_output", fake_print_agent_output)
|
||||||
|
_run_agent_stream(dummy_agent, "dummy prompt", {})
|
||||||
|
assert call_flag["called"]
|
||||||
|
assert _global_memory["plan_completed"] is False
|
||||||
|
assert _global_memory["task_completed"] is False
|
||||||
|
assert _global_memory["completion_message"] == ""
|
||||||
|
|
||||||
|
def test_execute_test_command_wrapper(monkeypatch):
|
||||||
|
from ra_aid.agent_utils import _execute_test_command_wrapper
|
||||||
|
# Patch execute_test_command to return a testable tuple
|
||||||
|
def fake_execute(config, orig, tests, auto):
|
||||||
|
return (True, "new prompt", auto, tests + 1)
|
||||||
|
monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
|
||||||
|
result = _execute_test_command_wrapper("orig", {}, 0, False)
|
||||||
|
assert result == (True, "new prompt", False, 1)
|
||||||
|
|
||||||
|
def test_handle_api_error_valueerror():
|
||||||
|
from ra_aid.agent_utils import _handle_api_error
|
||||||
|
import pytest
|
||||||
|
# ValueError not containing "code" or "429" should be re-raised
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_handle_api_error(ValueError("some error"), 0, 5, 1)
|
||||||
|
|
||||||
|
def test_handle_api_error_max_retries():
|
||||||
|
from ra_aid.agent_utils import _handle_api_error
|
||||||
|
import pytest
|
||||||
|
# When attempt reaches max retries, a RuntimeError should be raised
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
_handle_api_error(Exception("error code 429"), 4, 5, 1)
|
||||||
|
|
||||||
|
def test_handle_api_error_retry(monkeypatch):
|
||||||
|
from ra_aid.agent_utils import _handle_api_error
|
||||||
|
import time
|
||||||
|
# Patch time.monotonic and time.sleep to simulate immediate delay expiration
|
||||||
|
fake_time = [0]
|
||||||
|
def fake_monotonic():
|
||||||
|
fake_time[0] += 0.5
|
||||||
|
return fake_time[0]
|
||||||
|
monkeypatch.setattr(time, "monotonic", fake_monotonic)
|
||||||
|
monkeypatch.setattr(time, "sleep", lambda s: None)
|
||||||
|
# Should not raise error when attempt is lower than max retries
|
||||||
|
_handle_api_error(Exception("error code 429"), 0, 5, 1)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue