RA.Aid/ra_aid/agent_utils.py


"""Utility functions for working with agents."""
import signal
import sys
import threading
import time
import uuid
from typing import Any, Dict, List, Optional, Sequence

from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, trim_messages
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt.chat_agent_executor import AgentState
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel

from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.config import DEFAULT_RECURSION_LIMIT
from ra_aid.console.formatting import print_error, print_stage_header
from ra_aid.console.output import print_agent_output
from ra_aid.exceptions import AgentInterrupt
from ra_aid.logging_config import get_logger
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens
from ra_aid.project_info import (
    display_project_status,
    format_project_info,
    get_project_info,
)
from ra_aid.prompts import (
    EXPERT_PROMPT_SECTION_IMPLEMENTATION,
    EXPERT_PROMPT_SECTION_PLANNING,
    EXPERT_PROMPT_SECTION_RESEARCH,
    HUMAN_PROMPT_SECTION_IMPLEMENTATION,
    HUMAN_PROMPT_SECTION_PLANNING,
    HUMAN_PROMPT_SECTION_RESEARCH,
    IMPLEMENTATION_PROMPT,
    PLANNING_PROMPT,
    RESEARCH_ONLY_PROMPT,
    RESEARCH_PROMPT,
    WEB_RESEARCH_PROMPT,
    WEB_RESEARCH_PROMPT_SECTION_CHAT,
    WEB_RESEARCH_PROMPT_SECTION_PLANNING,
    WEB_RESEARCH_PROMPT_SECTION_RESEARCH,
)
from ra_aid.tool_configs import (
    get_implementation_tools,
    get_planning_tools,
    get_research_tools,
    get_web_research_tools,
)
from ra_aid.tools.memory import (
    _global_memory,
    get_memory_value,
    get_related_files,
)

console = Console()
logger = get_logger(__name__)
@tool
def output_markdown_message(message: str) -> str:
"""Outputs a message to the user, optionally prompting for input."""
console.print(Panel(Markdown(message.strip()), title="🤖 Assistant"))
return "Message output."
def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
"""Helper function to estimate total tokens in a sequence of messages.
Args:
messages: Sequence of messages to count tokens for
Returns:
Total estimated token count
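
    Example (illustrative):
        from langchain_core.messages import HumanMessage
        estimate_messages_tokens([HumanMessage(content="hello")])  # -> small integer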
"""
if not messages:
return 0
estimate_tokens = CiaynAgent._estimate_tokens
return sum(estimate_tokens(msg) for msg in messages)
def state_modifier(
state: AgentState, max_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
"""Given the agent state and max_tokens, return a trimmed list of messages.
Args:
state: The current agent state containing messages
max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)
Returns:
list[BaseMessage]: Trimmed list of messages that fits within token limit
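
    Example (illustrative; assumes ``msgs`` is a list of BaseMessage):
        trimmed = state_modifier({"messages": msgs}, max_tokens=8000)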
"""
messages = state["messages"]
if not messages:
return []
first_message = messages[0]
remaining_messages = messages[1:]
first_tokens = estimate_messages_tokens([first_message])
new_max_tokens = max_tokens - first_tokens
trimmed_remaining = trim_messages(
remaining_messages,
token_counter=estimate_messages_tokens,
max_tokens=new_max_tokens,
strategy="last",
allow_partial=False,
)
return [first_message] + trimmed_remaining
def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
"""Get the token limit for the current model configuration.
Returns:
Optional[int]: The token limit if found, None otherwise
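
    Example (illustrative):
        get_model_token_limit({"provider": "anthropic", "model": "claude-3-5-sonnet-20241022"})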
"""
try:
provider = config.get("provider", "")
model_name = config.get("model", "")
provider_tokens = models_tokens.get(provider, {})
token_limit = provider_tokens.get(model_name, None)
if token_limit:
logger.debug(
f"Found token limit for {provider}/{model_name}: {token_limit}"
)
else:
logger.debug(f"Could not find token limit for {provider}/{model_name}")
return token_limit
except Exception as e:
logger.warning(f"Failed to get model token limit: {e}")
return None
def build_agent_kwargs(
checkpointer: Optional[Any] = None,
    config: Optional[Dict[str, Any]] = None,
token_limit: Optional[int] = None,
) -> Dict[str, Any]:
"""Build kwargs dictionary for agent creation.
Args:
checkpointer: Optional memory checkpointer
config: Optional configuration dictionary
token_limit: Optional token limit for the model
Returns:
Dictionary of kwargs for agent creation
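
    Example (illustrative):
        kwargs = build_agent_kwargs(
            checkpointer=MemorySaver(),
            config={"provider": "anthropic", "model": "claude-3-haiku"},
            token_limit=100000,
        )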
"""
    agent_kwargs = {}
    config = config or {}
    if checkpointer is not None:
        agent_kwargs["checkpointer"] = checkpointer
    if config.get("limit_tokens", True) and is_anthropic_claude(config):
def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
return state_modifier(state, max_tokens=token_limit)
agent_kwargs["state_modifier"] = wrapped_state_modifier
return agent_kwargs
def is_anthropic_claude(config: Dict[str, Any]) -> bool:
"""Check if the provider and model name indicate an Anthropic Claude model.
Args:
provider: The provider name
model_name: The model name
Returns:
bool: True if this is an Anthropic Claude model
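
    Example (illustrative):
        is_anthropic_claude({"provider": "anthropic", "model": "claude-3-haiku"})  # True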
"""
provider = config.get("provider", "")
model_name = config.get("model", "")
    return (
        provider.lower() == "anthropic"
        and bool(model_name)
        and "claude" in model_name.lower()
    )
def create_agent(
model: BaseChatModel,
tools: List[Any],
*,
checkpointer: Any = None,
) -> Any:
"""Create a react agent with the given configuration.
Args:
model: The LLM model to use
tools: List of tools to provide to the agent
checkpointer: Optional memory checkpointer
    Returns:
        The created agent instance

    Note:
        Configuration is read from the global memory store rather than passed
        as a parameter, including settings such as:
            - limit_tokens (bool): Whether to apply token limiting (default: True)
            - provider (str): The LLM provider name
            - model (str): The model name

        Token limiting helps prevent context window overflow by trimming older
        messages while preserving the first message. It can be disabled by
        setting config["limit_tokens"] = False.
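
    Example (illustrative):
        agent = create_agent(model, tools, checkpointer=MemorySaver())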
"""
try:
config = _global_memory.get("config", {})
token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
if is_anthropic_claude(config):
logger.debug("Using create_react_agent to instantiate agent.")
agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
return create_react_agent(model, tools, **agent_kwargs)
else:
logger.debug("Using CiaynAgent agent instance")
return CiaynAgent(model, tools, max_tokens=token_limit)
except Exception as e:
# Default to REACT agent if provider/model detection fails
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
config = _global_memory.get("config", {})
        token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
return create_react_agent(model, tools, **agent_kwargs)
def run_research_agent(
base_task_or_query: str,
model,
*,
expert_enabled: bool = False,
research_only: bool = False,
hil: bool = False,
web_research_enabled: bool = False,
memory: Optional[Any] = None,
config: Optional[dict] = None,
thread_id: Optional[str] = None,
console_message: Optional[str] = None,
) -> Optional[str]:
"""Run a research agent with the given configuration.
Args:
base_task_or_query: The main task or query for research
model: The LLM model to use
expert_enabled: Whether expert mode is enabled
research_only: Whether this is a research-only task
hil: Whether human-in-the-loop mode is enabled
web_research_enabled: Whether web research is enabled
memory: Optional memory instance to use
config: Optional configuration dictionary
thread_id: Optional thread ID (defaults to new UUID)
console_message: Optional message to display before running
Returns:
Optional[str]: The completion message if task completed successfully
Example:
result = run_research_agent(
"Research Python async patterns",
model,
expert_enabled=True,
research_only=True
)
"""
thread_id = thread_id or str(uuid.uuid4())
logger.debug("Starting research agent with thread_id=%s", thread_id)
logger.debug(
"Research configuration: expert=%s, research_only=%s, hil=%s, web=%s",
expert_enabled,
research_only,
hil,
web_research_enabled,
)
    if memory is None:
        memory = MemorySaver()

    config = config or _global_memory.get("config", {})

    tools = get_research_tools(
        research_only=research_only,
        expert_enabled=expert_enabled,
        human_interaction=hil,
        web_research_enabled=config.get("web_research_enabled", False),
    )
agent = create_agent(model, tools, checkpointer=memory)
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
web_research_section = (
WEB_RESEARCH_PROMPT_SECTION_RESEARCH
if config.get("web_research_enabled")
else ""
)
key_facts = _global_memory.get("key_facts", "")
code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "")
    project_info = None
    formatted_project_info = ""
    try:
        project_info = get_project_info(".", file_limit=2000)
        formatted_project_info = format_project_info(project_info)
    except Exception as e:
        logger.warning(f"Failed to get project info: {e}")
prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
base_task=base_task_or_query,
research_only_note=""
if research_only
else " Only request implementation if the user explicitly asked for changes to be made.",
expert_section=expert_section,
human_section=human_section,
web_research_section=web_research_section,
key_facts=key_facts,
work_log=get_memory_value("work_log"),
code_snippets=code_snippets,
related_files=related_files,
project_info=formatted_project_info,
)
    recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
    run_config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": recursion_limit,
    }
    if config:
        run_config.update(config)
try:
if console_message:
console.print(
Panel(Markdown(console_message), title="🔬 Looking into it...")
)
if project_info:
display_project_status(project_info)
        if agent is not None:
            logger.debug("Running research agent")
return run_agent_with_retry(agent, prompt, run_config)
else:
logger.debug("No model provided, running web research tools directly")
return run_web_research_agent(
base_task_or_query,
model=None,
expert_enabled=expert_enabled,
hil=hil,
web_research_enabled=web_research_enabled,
memory=memory,
config=config,
thread_id=thread_id,
console_message=console_message,
)
except (KeyboardInterrupt, AgentInterrupt):
raise
except Exception as e:
logger.error("Research agent failed: %s", str(e), exc_info=True)
raise
def run_web_research_agent(
query: str,
model,
*,
expert_enabled: bool = False,
hil: bool = False,
web_research_enabled: bool = False,
memory: Optional[Any] = None,
config: Optional[dict] = None,
thread_id: Optional[str] = None,
console_message: Optional[str] = None,
) -> Optional[str]:
"""Run a web research agent with the given configuration.
Args:
        query: The main query for web research
model: The LLM model to use
expert_enabled: Whether expert mode is enabled
hil: Whether human-in-the-loop mode is enabled
web_research_enabled: Whether web research is enabled
memory: Optional memory instance to use
config: Optional configuration dictionary
thread_id: Optional thread ID (defaults to new UUID)
console_message: Optional message to display before running
Returns:
Optional[str]: The completion message if task completed successfully
Example:
result = run_web_research_agent(
"Research latest Python async patterns",
model,
expert_enabled=True
)
"""
thread_id = thread_id or str(uuid.uuid4())
logger.debug("Starting web research agent with thread_id=%s", thread_id)
logger.debug(
"Web research configuration: expert=%s, hil=%s, web=%s",
expert_enabled,
hil,
web_research_enabled,
)
    if memory is None:
        memory = MemorySaver()

    config = config or _global_memory.get("config", {})
tools = get_web_research_tools(expert_enabled=expert_enabled)
agent = create_agent(model, tools, checkpointer=memory)
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
key_facts = _global_memory.get("key_facts", "")
code_snippets = _global_memory.get("code_snippets", "")
related_files = _global_memory.get("related_files", "")
prompt = WEB_RESEARCH_PROMPT.format(
web_research_query=query,
expert_section=expert_section,
human_section=human_section,
key_facts=key_facts,
code_snippets=code_snippets,
related_files=related_files,
)
    recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
    run_config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": recursion_limit,
    }
    if config:
        run_config.update(config)
try:
if console_message:
console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
logger.debug("Web research agent completed successfully")
return run_agent_with_retry(agent, prompt, run_config)
except (KeyboardInterrupt, AgentInterrupt):
raise
except Exception as e:
logger.error("Web research agent failed: %s", str(e), exc_info=True)
raise
def run_planning_agent(
base_task: str,
model,
*,
expert_enabled: bool = False,
hil: bool = False,
memory: Optional[Any] = None,
config: Optional[dict] = None,
thread_id: Optional[str] = None,
) -> Optional[str]:
"""Run a planning agent to create implementation plans.
Args:
base_task: The main task to plan implementation for
model: The LLM model to use
expert_enabled: Whether expert mode is enabled
hil: Whether human-in-the-loop mode is enabled
memory: Optional memory instance to use
config: Optional configuration dictionary
thread_id: Optional thread ID (defaults to new UUID)
Returns:
Optional[str]: The completion message if planning completed successfully
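
    Example (illustrative):
        result = run_planning_agent(
            "Add a --verbose CLI flag",
            model,
            expert_enabled=True,
        )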
"""
thread_id = thread_id or str(uuid.uuid4())
logger.debug("Starting planning agent with thread_id=%s", thread_id)
logger.debug("Planning configuration: expert=%s, hil=%s", expert_enabled, hil)
    if memory is None:
        memory = MemorySaver()

    config = config or _global_memory.get("config", {})

    tools = get_planning_tools(
        expert_enabled=expert_enabled,
        web_research_enabled=config.get("web_research_enabled", False),
    )
agent = create_agent(model, tools, checkpointer=memory)
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
web_research_section = (
WEB_RESEARCH_PROMPT_SECTION_PLANNING
if config.get("web_research_enabled")
else ""
)
planning_prompt = PLANNING_PROMPT.format(
expert_section=expert_section,
human_section=human_section,
web_research_section=web_research_section,
base_task=base_task,
research_notes=get_memory_value("research_notes"),
related_files="\n".join(get_related_files()),
key_facts=get_memory_value("key_facts"),
key_snippets=get_memory_value("key_snippets"),
work_log=get_memory_value("work_log"),
research_only_note=""
if config.get("research_only")
else " Only request implementation if the user explicitly asked for changes to be made.",
)
    recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
    run_config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": recursion_limit,
    }
    if config:
        run_config.update(config)
try:
print_stage_header("Planning Stage")
logger.debug("Planning agent completed successfully")
return run_agent_with_retry(agent, planning_prompt, run_config)
except (KeyboardInterrupt, AgentInterrupt):
raise
except Exception as e:
logger.error("Planning agent failed: %s", str(e), exc_info=True)
raise
def run_task_implementation_agent(
base_task: str,
tasks: list,
task: str,
plan: str,
related_files: list,
model,
*,
expert_enabled: bool = False,
web_research_enabled: bool = False,
memory: Optional[Any] = None,
config: Optional[dict] = None,
thread_id: Optional[str] = None,
) -> Optional[str]:
"""Run an implementation agent for a specific task.
Args:
base_task: The main task being implemented
        tasks: List of all tasks being implemented
        task: The specific task to implement in this run
        plan: The implementation plan
related_files: List of related files
model: The LLM model to use
expert_enabled: Whether expert mode is enabled
web_research_enabled: Whether web research is enabled
memory: Optional memory instance to use
config: Optional configuration dictionary
thread_id: Optional thread ID (defaults to new UUID)
Returns:
Optional[str]: The completion message if task completed successfully
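
    Example (illustrative; names are hypothetical):
        result = run_task_implementation_agent(
            base_task="Add verbose logging",
            tasks=["Add CLI flag", "Wire flag into logger"],
            task="Add CLI flag",
            plan=plan_text,
            related_files=["ra_aid/__main__.py"],
            model=model,
        )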
"""
thread_id = thread_id or str(uuid.uuid4())
logger.debug("Starting implementation agent with thread_id=%s", thread_id)
logger.debug(
"Implementation configuration: expert=%s, web=%s",
expert_enabled,
web_research_enabled,
)
logger.debug("Task details: base_task=%s, current_task=%s", base_task, task)
logger.debug("Related files: %s", related_files)
    if memory is None:
        memory = MemorySaver()

    config = config or _global_memory.get("config", {})

    tools = get_implementation_tools(
        expert_enabled=expert_enabled,
        web_research_enabled=config.get("web_research_enabled", False),
    )
agent = create_agent(model, tools, checkpointer=memory)
prompt = IMPLEMENTATION_PROMPT.format(
base_task=base_task,
task=task,
tasks=tasks,
plan=plan,
related_files=related_files,
key_facts=get_memory_value("key_facts"),
key_snippets=get_memory_value("key_snippets"),
research_notes=get_memory_value("research_notes"),
work_log=get_memory_value("work_log"),
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION
if _global_memory.get("config", {}).get("hil", False)
else "",
web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT
if config.get("web_research_enabled")
else "",
)
    recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
    run_config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": recursion_limit,
    }
    if config:
        run_config.update(config)
try:
logger.debug("Implementation agent completed successfully")
return run_agent_with_retry(agent, prompt, run_config)
except (KeyboardInterrupt, AgentInterrupt):
raise
except Exception as e:
logger.error("Implementation agent failed: %s", str(e), exc_info=True)
raise
_CONTEXT_STACK = []
_INTERRUPT_CONTEXT = None
_FEEDBACK_MODE = False
def _request_interrupt(signum, frame):
    """SIGINT handler: flag the innermost interruptible section as interrupted."""
    global _INTERRUPT_CONTEXT
if _CONTEXT_STACK:
_INTERRUPT_CONTEXT = _CONTEXT_STACK[-1]
if _FEEDBACK_MODE:
print()
print(" 👋 Bye!")
print()
sys.exit(0)
class InterruptibleSection:
    """Context manager marking a scope that can be interrupted via SIGINT."""

    def __enter__(self):
_CONTEXT_STACK.append(self)
return self
def __exit__(self, exc_type, exc_value, traceback):
_CONTEXT_STACK.remove(self)
def check_interrupt():
    """Raise AgentInterrupt if an interrupt was requested for the active section."""
if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]:
raise AgentInterrupt("Interrupt requested")
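
# Illustrative usage of the interrupt plumbing (hypothetical loop):
#
#     with InterruptibleSection():
#         while work_remaining:      # hypothetical condition
#             check_interrupt()      # raises AgentInterrupt after Ctrl+C
#             do_work()              # hypothetical unit of work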
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
"""Run an agent with retry logic for API errors."""
logger.debug("Running agent with prompt length: %d", len(prompt))
original_handler = None
if threading.current_thread() is threading.main_thread():
original_handler = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, _request_interrupt)
max_retries = 20
base_delay = 1
with InterruptibleSection():
try:
# Track agent execution depth
current_depth = _global_memory.get("agent_depth", 0)
_global_memory["agent_depth"] = current_depth + 1
for attempt in range(max_retries):
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
check_interrupt()
try:
for chunk in agent.stream(
{"messages": [HumanMessage(content=prompt)]}, config
):
logger.debug("Agent output: %s", chunk)
check_interrupt()
print_agent_output(chunk)
if _global_memory["plan_completed"]:
_global_memory["plan_completed"] = False
_global_memory["task_completed"] = False
_global_memory["completion_message"] = ""
break
if _global_memory["task_completed"]:
_global_memory["task_completed"] = False
_global_memory["completion_message"] = ""
break
logger.debug("Agent run completed successfully")
return "Agent run completed successfully"
except (KeyboardInterrupt, AgentInterrupt):
raise
except (
InternalServerError,
APITimeoutError,
RateLimitError,
APIError,
ValueError,
) as e:
if isinstance(e, ValueError):
error_str = str(e).lower()
if "code" not in error_str or "429" not in error_str:
raise # Re-raise ValueError if it's not a Lambda 429
if attempt == max_retries - 1:
logger.error("Max retries reached, failing: %s", str(e))
raise RuntimeError(
f"Max retries ({max_retries}) exceeded. Last error: {e}"
)
logger.warning(
"API error (attempt %d/%d): %s",
attempt + 1,
max_retries,
str(e),
)
delay = base_delay * (2**attempt)
print_error(
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
)
start = time.monotonic()
while time.monotonic() - start < delay:
check_interrupt()
time.sleep(0.1)
finally:
# Reset depth tracking
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
            # `original_handler is not None` (not truthiness): signal.SIG_DFL
            # has integer value 0, so a bare truth test would skip restoration.
            if (
                original_handler is not None
                and threading.current_thread() is threading.main_thread()
            ):
                signal.signal(signal.SIGINT, original_handler)
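
# Illustrative usage (assumes a configured model and tool list):
#
#     agent = create_agent(model, tools, checkpointer=MemorySaver())
#     run_agent_with_retry(
#         agent,
#         "Summarize the repository layout",
#         {"configurable": {"thread_id": "example"}, "recursion_limit": 100},
#     )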