"""Utility functions for working with agents."""
|
|
|
|
import time
|
|
import uuid
|
|
from typing import Optional, Any, List
|
|
|
|
from langgraph.prebuilt import create_react_agent
|
|
from ra_aid.tool_configs import get_implementation_tools, get_research_tools
|
|
from ra_aid.prompts import (
|
|
IMPLEMENTATION_PROMPT,
|
|
EXPERT_PROMPT_SECTION_IMPLEMENTATION,
|
|
HUMAN_PROMPT_SECTION_IMPLEMENTATION,
|
|
EXPERT_PROMPT_SECTION_RESEARCH,
|
|
RESEARCH_PROMPT,
|
|
HUMAN_PROMPT_SECTION_RESEARCH
|
|
)
|
|
from langgraph.checkpoint.memory import MemorySaver
|
|
|
|
from langchain_core.messages import HumanMessage
|
|
from langchain_core.messages import BaseMessage
|
|
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
|
|
from rich.console import Console
|
|
from rich.markdown import Markdown
|
|
from rich.panel import Panel
|
|
|
|
from ra_aid.tools.memory import (
|
|
_global_memory,
|
|
get_memory_value,
|
|
)
|
|
from ra_aid.globals import RESEARCH_AGENT_RECURSION_LIMIT
|
|
from ra_aid.tool_configs import get_research_tools
|
|
from ra_aid.prompts import (
|
|
RESEARCH_PROMPT,
|
|
EXPERT_PROMPT_SECTION_RESEARCH,
|
|
HUMAN_PROMPT_SECTION_RESEARCH
|
|
)
|
|
|
|
console = Console()
|
|
|
|
def run_research_agent(
    base_task_or_query: str,
    model,
    *,
    expert_enabled: bool = False,
    research_only: bool = False,
    hil: bool = False,
    memory: Optional[Any] = None,
    config: Optional[dict] = None,
    thread_id: Optional[str] = None,
    console_message: Optional[str] = None,
) -> Optional[str]:
    """Run a research agent with the given configuration.

    Args:
        base_task_or_query: The main task or query for research
        model: The LLM model to use
        expert_enabled: Whether expert mode is enabled
        research_only: Whether this is a research-only task
        hil: Whether human-in-the-loop mode is enabled
        memory: Optional memory instance to use
        config: Optional configuration dictionary
        thread_id: Optional thread ID (defaults to new UUID)
        console_message: Optional message to display before running

    Returns:
        Optional[str]: The completion message if task completed successfully

    Example:
        result = run_research_agent(
            "Research Python async patterns",
            model,
            expert_enabled=True,
            research_only=True,
        )
    """
    # Initialize memory if not provided
    if memory is None:
        memory = MemorySaver()

    # Set up thread ID
    if thread_id is None:
        thread_id = str(uuid.uuid4())

    # Configure tools
    tools = get_research_tools(
        research_only=research_only,
        expert_enabled=expert_enabled,
        human_interaction=hil,
    )

    # Create agent
    agent = create_react_agent(model, tools, checkpointer=memory)

    # Format prompt sections
    expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
    human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""

    # Build prompt
    prompt = RESEARCH_PROMPT.format(
        base_task=base_task_or_query,
        research_only_note=(
            ''
            if research_only
            else ' Only request implementation if the user explicitly asked for changes to be made.'
        ),
        expert_section=expert_section,
        human_section=human_section,
    )

    # Set up configuration
    run_config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": RESEARCH_AGENT_RECURSION_LIMIT,
    }
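    # Any caller-supplied config is merged last, so it can override the
    # thread_id and recursion_limit defaults set above.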
    if config:
        run_config.update(config)

    # Display console message if provided
    if console_message:
        console.print(Panel(Markdown(console_message), title="🔬 Research Task"))

    # Run agent with retry logic
    return run_agent_with_retry(agent, prompt, run_config)


def print_agent_output(chunk: dict[str, BaseMessage]) -> None:
    """Print agent output chunks."""
    if chunk.get("delta") and chunk["delta"].content:
        console.print(chunk["delta"].content, end="", style="blue")


def print_error(msg: str) -> None:
    """Print error messages."""
    console.print(f"\n{msg}", style="red")


def run_task_implementation_agent(
    base_task: str,
    tasks: list,
    task: str,
    plan: str,
    related_files: list,
    model,
    *,
    expert_enabled: bool = False,
    memory: Optional[Any] = None,
    config: Optional[dict] = None,
    thread_id: Optional[str] = None,
) -> Optional[str]:
    """Run an implementation agent for a specific task.

    Args:
        base_task: The main task being implemented
        tasks: List of tasks to implement
        task: The specific task to implement in this run
        plan: The implementation plan
        related_files: List of related files
        model: The LLM model to use
        expert_enabled: Whether expert mode is enabled
        memory: Optional memory instance to use
        config: Optional configuration dictionary
        thread_id: Optional thread ID (defaults to new UUID)

    Returns:
        Optional[str]: The completion message if task completed successfully
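
    Example:
        # Illustrative call; all values below are placeholders.
        result = run_task_implementation_agent(
            "Add user authentication",
            tasks=["Create login page", "Add session handling"],
            task="Create login page",
            plan="Build the login form, then wire up session handling",
            related_files=["app.py"],
            model=model,
            expert_enabled=True,
        )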
"""
|
|
    # Initialize memory if not provided
    if memory is None:
        memory = MemorySaver()

    # Set up thread ID
    if thread_id is None:
        thread_id = str(uuid.uuid4())

    # Configure tools
    tools = get_implementation_tools(expert_enabled=expert_enabled)

    # Create agent
    agent = create_react_agent(model, tools, checkpointer=memory)

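    # Note: unlike run_research_agent, human-in-the-loop mode is read from the
    # global config below rather than passed in as a parameter.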
    # Build prompt
    prompt = IMPLEMENTATION_PROMPT.format(
        base_task=base_task,
        task=task,
        tasks=tasks,
        plan=plan,
        related_files=related_files,
        key_facts=get_memory_value('key_facts'),
        key_snippets=get_memory_value('key_snippets'),
        expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
        human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION if _global_memory.get('config', {}).get('hil', False) else "",
    )

    # Set up configuration
    run_config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": 100,
    }
    if config:
        run_config.update(config)

    # Run agent with retry logic
    return run_agent_with_retry(agent, prompt, run_config)


def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
    """Run an agent with retry logic for transient API errors and task completion handling.

    Args:
        agent: The agent to run
        prompt: The prompt to send to the agent
        config: Configuration dictionary for the agent

    Returns:
        Optional[str]: The completion message if task was completed, None otherwise

    Handles API errors (rate limits, timeouts, server errors) with exponential
    backoff retry logic and checks for task completion after each chunk of output.
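
    Example:
        # Illustrative usage; `agent` is a prebuilt LangGraph ReAct agent.
        result = run_agent_with_retry(
            agent,
            "Summarize the repository structure",
            {"configurable": {"thread_id": "demo"}, "recursion_limit": 100},
        )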
    """
    max_retries = 20
    base_delay = 1  # Initial delay in seconds
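    # Delays double per attempt (1s, 2s, 4s, ...); with up to 20 attempts the
    # later waits grow very large, so tune max_retries/base_delay for your rate limits.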

    for attempt in range(max_retries):
        try:
            for chunk in agent.stream(
                {"messages": [HumanMessage(content=prompt)]},
                config,
            ):
                print_agent_output(chunk)
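            # Stream completed without raising an API error; stop retrying.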
            break
        except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
            if attempt == max_retries - 1:
                raise RuntimeError(
                    f"Max retries ({max_retries}) exceeded. Last error: {str(e)}"
                ) from e

            delay = base_delay * (2 ** attempt)  # Exponential backoff
            error_type = e.__class__.__name__
            print_error(
                f"Encountered {error_type}: {str(e)}. "
                f"Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})"
            )
            time.sleep(delay)
            continue