# ra_aid/agent_utils.py
"""Utility functions for working with agents."""
import signal
import sys
import threading
import time
from typing import Any, Dict, List, Literal, Optional
from langchain_anthropic import ChatAnthropic
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
from openai import RateLimitError as OpenAIRateLimitError
from litellm.exceptions import RateLimitError as LiteLLMRateLimitError
from google.api_core.exceptions import ResourceExhausted
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt.chat_agent_executor import AgentState
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.agent_context import (
    agent_context,
    is_completed,
    reset_completion_flags,
    should_exit,
)
from ra_aid.agent_backends.ciayn_agent import CiaynAgent
from ra_aid.agents_alias import RAgents
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES
from ra_aid.console.formatting import print_error
from ra_aid.console.output import print_agent_output
from ra_aid.exceptions import (
    AgentInterrupt,
    FallbackToolExecutionError,
    ToolExecutionError,
)
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.anthropic_token_limiter import state_modifier, get_model_token_limit

console = Console()
logger = get_logger(__name__)

# Import repositories using get_* functions


@tool
def output_markdown_message(message: str) -> str:
    """Output a markdown-formatted message to the user."""
    console.print(Panel(Markdown(message.strip()), title="🤖 Assistant"))
    return "Message output."


def build_agent_kwargs(
    checkpointer: Optional[Any] = None,
    model: Optional[ChatAnthropic] = None,
    max_input_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """Build the kwargs dictionary for agent creation.

    Args:
        checkpointer: Optional memory checkpointer
        model: The language model to use for token counting
        max_input_tokens: Optional token limit for the model

    Returns:
        Dictionary of kwargs for agent creation
    """
    agent_kwargs = {
        "version": "v2",
    }

    if checkpointer is not None:
        agent_kwargs["checkpointer"] = checkpointer

    config = get_config_repository().get_all()
    if (
        config.get("limit_tokens", True)
        and is_anthropic_claude(config)
        and model is not None
    ):

        def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
            return state_modifier(state, model, max_input_tokens=max_input_tokens)

        agent_kwargs["state_modifier"] = wrapped_state_modifier

    agent_kwargs["name"] = "React"
    return agent_kwargs
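
# Example (hypothetical values; assumes an initialized config repository and a
# Claude config with limit_tokens enabled):
#
#     kwargs = build_agent_kwargs(checkpointer=None, model=claude_model,
#                                 max_input_tokens=100_000)
#     # -> {"version": "v2", "state_modifier": <fn>, "name": "React"}, where the
#     #    state_modifier trims older messages before each model call.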


def is_anthropic_claude(config: Dict[str, Any]) -> bool:
    """Check if the provider and model name indicate an Anthropic Claude model.

    Args:
        config: Configuration dictionary containing provider and model information

    Returns:
        bool: True if this is an Anthropic Claude model
    """
    # For backwards compatibility, allow passing of config directly
    provider = config.get("provider", "")
    model_name = config.get("model", "")
    result = (
        provider.lower() == "anthropic"
        and model_name
        and "claude" in model_name.lower()
    ) or (
        provider.lower() == "openrouter"
        and model_name.lower().startswith("anthropic/claude-")
    )
    return bool(result)
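
# Examples (derived from the matching rules above):
#
#     >>> is_anthropic_claude({"provider": "anthropic", "model": "claude-3-5-sonnet"})
#     True
#     >>> is_anthropic_claude({"provider": "openrouter", "model": "anthropic/claude-3-opus"})
#     True
#     >>> is_anthropic_claude({"provider": "openai", "model": "gpt-4o"})
#     False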


def create_agent(
    model: BaseChatModel,
    tools: List[Any],
    *,
    checkpointer: Any = None,
    agent_type: str = "default",
):
    """Create a react agent with the given configuration.

    Args:
        model: The LLM model to use
        tools: List of tools to provide to the agent
        checkpointer: Optional memory checkpointer
        agent_type: Agent type key used when looking up the model token limit

    Configuration is read from the config repository and includes settings like:
        - limit_tokens (bool): Whether to apply token limiting (default: True)
        - provider (str): The LLM provider name
        - model (str): The model name

    Returns:
        The created agent instance

    Token limiting helps prevent context window overflow by trimming older messages
    while preserving system messages. It can be disabled by setting
    config['limit_tokens'] = False.
    """
    try:
        config: Dict[str, Any] = {}
        # Try to get config from the repository for production use
        try:
            config = get_config_repository().get_all()
        except RuntimeError:
            # In tests, the repository may not be set up, so fall back to an
            # empty config rather than failing outright.
            pass

        max_input_tokens = (
            get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
        )

        # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
        if is_anthropic_claude(config):
            logger.debug("Using create_react_agent to instantiate agent.")
            agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
            return create_react_agent(
                model, tools, interrupt_after=["tools"], **agent_kwargs
            )
        else:
            logger.debug("Using CiaynAgent agent instance")
            return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
    except Exception as e:
        # Default to REACT agent if provider/model detection fails
        logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
        config = get_config_repository().get_all()
        max_input_tokens = get_model_token_limit(config, agent_type)
        agent_kwargs = build_agent_kwargs(checkpointer, model, max_input_tokens)
        return create_react_agent(
            model, tools, interrupt_after=["tools"], **agent_kwargs
        )
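
# Example (hedged sketch; the model name and tool list are illustrative
# assumptions, not fixed by this module):
#
#     from langchain_anthropic import ChatAnthropic
#     model = ChatAnthropic(model="claude-3-5-sonnet-20241022")
#     agent = create_agent(model, [output_markdown_message])
#     # -> a LangGraph react agent for Claude configs, a CiaynAgent otherwise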


_CONTEXT_STACK = []
_INTERRUPT_CONTEXT = None
_FEEDBACK_MODE = False


def _request_interrupt(signum, frame):
    global _INTERRUPT_CONTEXT
    if _CONTEXT_STACK:
        _INTERRUPT_CONTEXT = _CONTEXT_STACK[-1]
    if _FEEDBACK_MODE:
        print()
        print(" 👋 Bye!")
        print()
        sys.exit(0)


class InterruptibleSection:
    def __enter__(self):
        _CONTEXT_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        _CONTEXT_STACK.remove(self)


def check_interrupt():
    if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]:
        raise AgentInterrupt("Interrupt requested")
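
# Example (sketch of the interrupt pattern these helpers implement): long-running
# work wraps itself in an InterruptibleSection and polls check_interrupt(); when
# SIGINT fires, _request_interrupt marks the innermost active section, and the
# next poll raises AgentInterrupt instead of the process dying mid-operation.
#
#     with InterruptibleSection():
#         while has_work():          # hypothetical loop condition
#             check_interrupt()      # raises AgentInterrupt after Ctrl+C
#             do_next_step()         # hypothetical unit of work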


# New helper functions for run_agent_with_retry refactoring
def _setup_interrupt_handling():
    if threading.current_thread() is threading.main_thread():
        original_handler = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, _request_interrupt)
        return original_handler
    return None


def _restore_interrupt_handling(original_handler):
    if original_handler and threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGINT, original_handler)


def reset_agent_completion_flags():
    """Reset completion flags in the current context."""
    reset_completion_flags()


def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
    # Config is passed in directly for backwards compatibility; no repository
    # lookup is needed here.
    return execute_test_command(config, original_prompt, test_attempts, auto_test)


def _handle_api_error(e, attempt, max_retries, base_delay):
    # 1. A ValueError is only retried when it looks like a rate-limit error
    #    (mentions a 429 code or a rate-limit phrase); anything else re-raises.
    if isinstance(e, ValueError):
        error_str = str(e).lower()
        rate_limit_phrases = [
            "429",
            "rate limit",
            "too many requests",
            "quota exceeded",
        ]
        if "code" not in error_str and not any(
            phrase in error_str for phrase in rate_limit_phrases
        ):
            raise e

    # 2. Check for a status_code or http_status attribute equal to 429
    if hasattr(e, "status_code") and e.status_code == 429:
        pass  # This is a rate limit error; continue with retry logic
    elif hasattr(e, "http_status") and e.http_status == 429:
        pass  # This is a rate limit error; continue with retry logic
    # 3. Check for rate limit phrases in the error message
    elif isinstance(e, Exception) and not isinstance(e, ValueError):
        error_str = str(e).lower()
        if not any(
            phrase in error_str
            for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]
        ) and not ("rate" in error_str and "limit" in error_str):
            # Doesn't look like a rate limit error, but other API errors are
            # still retried below.
            pass

    # Apply common retry logic for all identified errors
    if attempt == max_retries - 1:
        logger.error("Max retries reached, failing: %s", str(e))
        raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")

    logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
    delay = base_delay * (2**attempt)
    print_error(
        f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... "
        f"(Attempt {attempt + 1}/{max_retries})"
    )
    start = time.monotonic()
    while time.monotonic() - start < delay:
        check_interrupt()
        time.sleep(0.1)
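
# Example (schedule implied by delay = base_delay * (2**attempt) above): with
# base_delay = 1, successive retries wait 1s, 2s, 4s, 8s, ..., and the wait
# loop polls check_interrupt() every 100 ms so Ctrl+C is still honored:
#
#     for attempt in range(4):
#         print(1 * (2**attempt))   # -> 1, 2, 4, 8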


def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]:
    """
    Determines the type of the agent.

    Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React".
    """
    if isinstance(agent, CiaynAgent):
        return "CiaynAgent"
    else:
        return "React"


def init_fallback_handler(agent: RAgents, tools: List[Any]):
    """
    Initialize the fallback handler if the agent is of type "React" and
    experimental_fallback_handler is enabled; otherwise return None.
    """
    if not get_config_repository().get("experimental_fallback_handler", False):
        return None
    agent_type = get_agent_type(agent)
    if agent_type == "React":
        return FallbackHandler(get_config_repository().get_all(), tools)
    return None


def _handle_fallback_response(
    error: ToolExecutionError,
    fallback_handler: Optional[FallbackHandler],
    agent: RAgents,
    msg_list: list,
) -> None:
    """
    Handle fallback response by invoking fallback_handler and updating msg_list.
    """
    if not fallback_handler:
        return
    fallback_response = fallback_handler.handle_failure(error, agent, msg_list)
    agent_type = get_agent_type(agent)
    if fallback_response and agent_type == "React":
        msg_list_response = [HumanMessage(str(msg)) for msg in fallback_response]
        msg_list.extend(msg_list_response)


def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
    """
    Stream agent output while handling completion and interruption.

    For each chunk, this logs the output, calls check_interrupt(), prints the
    agent output, and then checks whether is_completed() or should_exit() is
    true; if so, it resets the completion flags and returns. After a stream
    iteration finishes (i.e. the for-loop over chunks ends), the function
    retrieves the agent's state. If the state indicates further steps
    (state.next is non-empty), it resumes execution via agent.invoke(None,
    config); otherwise it exits the loop.

    This follows current LangGraph guidance (as of March 2025) for handling
    human-in-the-loop interruptions using interrupt_after=["tools"].
    """
    config = get_config_repository().get_all()
    stream_config = config.copy()

    cb = None
    if is_anthropic_claude(config):
        model_name = config.get("model", "")
        cb = AnthropicCallbackHandler(model_name)
        if "callbacks" not in stream_config:
            stream_config["callbacks"] = []
        stream_config["callbacks"].append(cb)

    while True:
        for chunk in agent.stream({"messages": msg_list}, stream_config):
            logger.debug("Agent output: %s", chunk)
            check_interrupt()
            agent_type = get_agent_type(agent)
            print_agent_output(chunk, agent_type, cost_cb=cb)

            if is_completed() or should_exit():
                reset_completion_flags()
                if cb:
                    logger.debug(f"AnthropicCallbackHandler:\n{cb}")
                return True

        logger.debug("Stream iteration ended; checking agent state for continuation.")

        # Prepare state configuration, ensuring 'configurable' is present.
        state_config = get_config_repository().get_all().copy()
        if "configurable" not in state_config:
            logger.debug(
                "Key 'configurable' not found in config; adding it as an empty dict."
            )
            state_config["configurable"] = {}

        logger.debug("Using state_config for agent.get_state(): %s", state_config)
        try:
            state = agent.get_state(state_config)
            logger.debug("Agent state retrieved: %s", state)
        except Exception as e:
            logger.error(
                "Error retrieving agent state with state_config %s: %s",
                state_config,
                e,
            )
            raise

        if state.next:
            logger.debug(
                "State indicates continuation (state.next: %s); resuming execution.",
                state.next,
            )
            agent.invoke(None, stream_config)
            continue
        else:
            logger.debug("No continuation indicated in state; exiting stream loop.")
            break

    if cb:
        logger.debug(f"AnthropicCallbackHandler:\n{cb}")
    return True
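
# Example (conceptual sketch of the loop above): with interrupt_after=["tools"]
# the graph pauses after each tool call, so one call proceeds roughly as:
#
#     _run_agent_stream(agent, [HumanMessage("Fix the failing test")])
#     # stream chunks -> tool call -> pause; state.next is non-empty, so
#     # agent.invoke(None, stream_config) resumes, until state.next is empty.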


def run_agent_with_retry(
    agent: RAgents,
    prompt: str,
    fallback_handler: Optional[FallbackHandler] = None,
) -> Optional[str]:
    """Run an agent with retry logic for API errors."""
    logger.debug("Running agent with prompt length: %d", len(prompt))
    original_handler = _setup_interrupt_handling()
    max_retries = 20
    base_delay = 1
    test_attempts = 0
    _max_test_retries = get_config_repository().get(
        "max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES
    )
    auto_test = get_config_repository().get("auto_test", False)
    original_prompt = prompt
    msg_list = [HumanMessage(content=prompt)]
    run_config = get_config_repository().get_all()

    # Create a new agent context for this run
    with InterruptibleSection(), agent_context() as ctx:
        try:
            for attempt in range(max_retries):
                logger.debug("Attempt %d/%d", attempt + 1, max_retries)
                check_interrupt()

                # Check if the agent has crashed before attempting to run it
                from ra_aid.agent_context import get_crash_message, is_crashed

                if is_crashed():
                    crash_message = get_crash_message()
                    logger.error("Agent has crashed: %s", crash_message)
                    return f"Agent has crashed: {crash_message}"

                try:
                    _run_agent_stream(agent, msg_list)
                    if fallback_handler:
                        fallback_handler.reset_fallback_handler()
                    should_break, prompt, auto_test, test_attempts = (
                        _execute_test_command_wrapper(
                            original_prompt, run_config, test_attempts, auto_test
                        )
                    )
                    if should_break:
                        break
                    if prompt != original_prompt:
                        continue

                    logger.debug("Agent run completed successfully")
                    return "Agent run completed successfully"
                except ToolExecutionError as e:
                    # A BadRequestError (HTTP 400) is unretryable
                    error_str = str(e).lower()
                    if "400" in error_str or "bad request" in error_str:
                        from ra_aid.agent_context import mark_agent_crashed

                        crash_message = f"Unretryable error: {str(e)}"
                        mark_agent_crashed(crash_message)
                        logger.error("Agent has crashed: %s", crash_message)
                        return f"Agent has crashed: {crash_message}"

                    _handle_fallback_response(e, fallback_handler, agent, msg_list)
                    continue
                except FallbackToolExecutionError as e:
                    msg_list.append(
                        SystemMessage(f"FallbackToolExecutionError:{str(e)}")
                    )
                except (KeyboardInterrupt, AgentInterrupt):
                    raise
                except (
                    InternalServerError,
                    APITimeoutError,
                    RateLimitError,
                    OpenAIRateLimitError,
                    LiteLLMRateLimitError,
                    ResourceExhausted,
                    APIError,
                    ValueError,
                ) as e:
                    # A BadRequestError (HTTP 400) from the API is unretryable
                    error_str = str(e).lower()
                    if (
                        "400" in error_str or "bad request" in error_str
                    ) and isinstance(e, APIError):
                        from ra_aid.agent_context import mark_agent_crashed

                        crash_message = f"Unretryable API error: {str(e)}"
                        mark_agent_crashed(crash_message)
                        logger.error("Agent has crashed: %s", crash_message)
                        return f"Agent has crashed: {crash_message}"

                    _handle_api_error(e, attempt, max_retries, base_delay)
        finally:
            _restore_interrupt_handling(original_handler)
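
# Example (hedged end-to-end sketch; model and prompt are illustrative
# assumptions):
#
#     model = ChatAnthropic(model="claude-3-5-sonnet-20241022")
#     agent = create_agent(model, [output_markdown_message])
#     result = run_agent_with_retry(
#         agent, "Summarize the repo layout",
#         fallback_handler=init_fallback_handler(agent, []),
#     )
#     # -> "Agent run completed successfully" on success, or an
#     #    "Agent has crashed: ..." string on an unretryable (HTTP 400) error.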