Allow chat agent to spawn research agent.

AI Christianson 2024-12-21 10:58:08 -05:00
parent f712fcf9c8
commit 1fc64c5151
7 changed files with 177 additions and 66 deletions

View File

@ -2,6 +2,7 @@ from .__version__ import __version__
from .console.formatting import print_stage_header, print_task_header, print_error
from .console.output import print_agent_output
from .text.processing import truncate_output
from .agent_utils import run_agent_with_retry
__all__ = [
    'print_stage_header',
@ -9,5 +10,6 @@ __all__ = [
    'print_agent_output',
    'truncate_output',
    'print_error',
    'run_agent_with_retry',
    '__version__'
]

View File

@ -1,15 +1,12 @@
import argparse
import sys
from typing import Optional
from rich.panel import Panel
from rich.markdown import Markdown
from rich.console import Console
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from ra_aid.env import validate_environment
from ra_aid.tools.memory import _global_memory, get_related_files, get_memory_value
from ra_aid import print_agent_output, print_stage_header, print_task_header, print_error
from ra_aid import print_stage_header, print_task_header, print_error, run_agent_with_retry
from ra_aid.prompts import (
    RESEARCH_PROMPT,
    PLANNING_PROMPT,
@ -22,8 +19,6 @@ from ra_aid.prompts import (
    HUMAN_PROMPT_SECTION_PLANNING,
    HUMAN_PROMPT_SECTION_IMPLEMENTATION
)
import time
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
from ra_aid.llm import initialize_llm
from ra_aid.tool_configs import (
@ -132,51 +127,6 @@ def is_stage_requested(stage: str) -> bool:
        return _global_memory.get('implementation_requested', False)
    return False


def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
    """Run an agent with retry logic for internal server errors and task completion handling.

    Args:
        agent: The agent to run
        prompt: The prompt to send to the agent
        config: Configuration dictionary for the agent

    Returns:
        Optional[str]: The completion message if task was completed, None otherwise

    Handles API errors with exponential backoff retry logic and checks for task
    completion after each chunk of output.
    """
    max_retries = 20
    base_delay = 1  # Initial delay in seconds

    for attempt in range(max_retries):
        try:
            for chunk in agent.stream(
                {"messages": [HumanMessage(content=prompt)]},
                config
            ):
                print_agent_output(chunk)

                # Check for task completion after each chunk
                if _global_memory.get('task_completed'):
                    completion_msg = _global_memory.get('completion_message', 'Task was completed successfully.')
                    console.print(Panel(
                        Markdown(completion_msg),
                        title="✅ Task Completed",
                        style="green"
                    ))
                    return completion_msg

            break
        except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
            if attempt == max_retries - 1:
                raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
            delay = base_delay * (2 ** attempt)  # Exponential backoff
            error_type = e.__class__.__name__
            print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
            time.sleep(delay)
            continue


def run_implementation_stage(base_task, tasks, plan, related_files, model, expert_enabled: bool):
    """Run implementation stage with a distinct agent for each task."""
    if not is_stage_requested('implementation'):
@ -216,7 +166,6 @@ def run_implementation_stage(base_task, tasks, plan, related_files, model, exper
        # Run agent for this task
        run_agent_with_retry(task_agent, task_prompt, {"configurable": {"thread_id": "abc123"}, "recursion_limit": 100})


def run_research_subtasks(base_task: str, config: dict, model, expert_enabled: bool):
    """Run research subtasks with separate agents."""
    subtasks = _global_memory.get('research_subtasks', [])

ra_aid/agent_utils.py Normal file
View File

@ -0,0 +1,69 @@
"""Utility functions for working with agents."""
import time
from typing import Optional
from langchain_core.messages import HumanMessage
from langchain_core.messages import BaseMessage
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.tools.memory import _global_memory
console = Console()
def print_agent_output(chunk: dict[str, BaseMessage]) -> None:
"""Print agent output chunks."""
if chunk.get("delta") and chunk["delta"].content:
console.print(chunk["delta"].content, end="", style="blue")
def print_error(msg: str) -> None:
"""Print error messages."""
console.print(f"\n{msg}", style="red")
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
"""Run an agent with retry logic for internal server errors and task completion handling.
Args:
agent: The agent to run
prompt: The prompt to send to the agent
config: Configuration dictionary for the agent
Returns:
Optional[str]: The completion message if task was completed, None otherwise
Handles API errors with exponential backoff retry logic and checks for task
completion after each chunk of output.
"""
max_retries = 20
base_delay = 1 # Initial delay in seconds
for attempt in range(max_retries):
try:
for chunk in agent.stream(
{"messages": [HumanMessage(content=prompt)]},
config
):
print_agent_output(chunk)
# Check for task completion after each chunk
if _global_memory.get('task_completed'):
completion_msg = _global_memory.get('completion_message', 'Task was completed successfully.')
console.print(Panel(
Markdown(completion_msg),
title="✅ Task Completed",
style="green"
))
return completion_msg
break
except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
if attempt == max_retries - 1:
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
delay = base_delay * (2 ** attempt) # Exponential backoff
error_type = e.__class__.__name__
print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
time.sleep(delay)
continue
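
To show how callers are expected to use this helper, here is a minimal, hypothetical usage sketch; the model name, empty toolset, and thread id are placeholders rather than values taken from this commit.

from langgraph.prebuilt import create_react_agent
from ra_aid.agent_utils import run_agent_with_retry
from ra_aid.llm import initialize_llm

# Placeholder model and an empty toolset, purely for illustration;
# real callers pass the stage-specific tools.
model = initialize_llm("anthropic", "claude-3-sonnet-20240229")
agent = create_react_agent(model, [])
completion = run_agent_with_retry(
    agent,
    "Summarize the repository layout.",
    {"configurable": {"thread_id": "example-thread"}, "recursion_limit": 100},
)
# completion holds the completion message if the task was marked done, otherwise None.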

View File

@ -23,27 +23,23 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
        return ChatOpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            model=model_name,
            temperature=0
        )
    elif provider == "anthropic":
        return ChatAnthropic(
            api_key=os.getenv("ANTHROPIC_API_KEY"),
            model_name=model_name,
            temperature=0
        )
    elif provider == "openrouter":
        return ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model=model_name,
            temperature=0
        )
    elif provider == "openai-compatible":
        return ChatOpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_API_BASE"),
            model=model_name,
            temperature=0
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
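
A brief usage sketch for the provider dispatch above, assuming the matching API key is already exported; the provider and model names below are placeholders.

import os
from ra_aid.llm import initialize_llm

# Placeholder provider/model; each branch above reads its API key from the
# environment, e.g. ANTHROPIC_API_KEY for the "anthropic" provider.
assert os.getenv("ANTHROPIC_API_KEY") is not None
llm = initialize_llm("anthropic", "claude-3-sonnet-20240229")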

View File

@ -9,6 +9,7 @@ from ra_aid.tools import (
    swap_task_order, monorepo_detected, existing_project_detected, ui_detected
)
from ra_aid.tools.memory import one_shot_completed
from ra_aid.tools.agent import request_research


# Read-only tools that don't modify system state
def get_read_only_tools(human_interaction: bool = False) -> list:
@ -61,6 +62,9 @@ def get_research_tools(research_only: bool = False, expert_enabled: bool = True,
    if expert_enabled:
        tools.extend(EXPERT_TOOLS)

    # Add chat-specific tools
    tools.append(request_research)

    return tools


def get_planning_tools(expert_enabled: bool = True) -> list:
@ -103,15 +107,10 @@ def get_chat_tools(expert_enabled: bool = True) -> list:
    Chat mode includes research and implementation capabilities but excludes
    complex planning tools. Human interaction is always enabled.
    """
    # Start with read-only tools and always include human interaction
    tools = get_read_only_tools(human_interaction=True).copy()

    # Add implementation capability
    tools.extend(MODIFICATION_TOOLS)

    # Add research tools except for subtask management
    research_tools = [t for t in RESEARCH_TOOLS if t.name != 'request_research_subtask']
    tools.extend(research_tools)
    tools = [
        ask_human,
        request_research
    ]

    # Add expert tools if enabled
    if expert_enabled:

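For context, a hypothetical sketch of how the trimmed chat toolset could be wired into a chat agent; the entry point itself is not shown in this diff, so the wiring below is an assumption.

from langgraph.prebuilt import create_react_agent
from ra_aid.llm import initialize_llm
from ra_aid.tool_configs import get_chat_tools

# Assumed wiring: the chat agent now carries ask_human plus request_research,
# letting it delegate investigation to a spawned research agent.
model = initialize_llm("anthropic", "claude-3-sonnet-20240229")
chat_agent = create_react_agent(model, get_chat_tools(expert_enabled=True))
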
ra_aid/tools/agent.py Normal file
View File

@ -0,0 +1,95 @@
"""Tools for spawning and managing sub-agents."""
from langchain_core.tools import tool
from typing import Dict, Any, List, Optional
import uuid
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from ra_aid.tools.memory import _global_memory
from ra_aid import run_agent_with_retry
from ..prompts import RESEARCH_PROMPT
from .memory import get_memory_value, get_related_files
from ..llm import initialize_llm
console = Console()
@tool("request_research")
def request_research(query: str) -> Dict[str, Any]:
"""Spawn a research-only agent to investigate the given query.
Args:
query: The research question or project description
Returns:
Dict containing:
- notes: Research notes from the agent
- facts: Current key facts
- files: Related files
- success: Whether completed or interrupted
"""
# Initialize model and memory
model = initialize_llm("anthropic", "claude-3-sonnet-20240229")
memory = MemorySaver()
memory.memory = _global_memory
# Configure research tools
from ..tool_configs import get_research_tools
tools = get_research_tools(research_only=True, expert_enabled=True)
# Basic config matching main process
config = {
"thread_id": str(uuid.uuid4()),
"memory": memory,
"model": model
}
from ra_aid.prompts import (
RESEARCH_PROMPT,
EXPERT_PROMPT_SECTION_RESEARCH,
HUMAN_PROMPT_SECTION_RESEARCH
)
# Create research agent
config = _global_memory.get('config', {})
expert_enabled = config.get('expert_enabled', False)
hil = config.get('hil', False)
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
agent = create_react_agent(model, tools)
prompt = RESEARCH_PROMPT.format(
base_task=query,
research_only_note='',
expert_section=expert_section,
human_section=human_section
)
try:
console.print(Panel(Markdown(query), title="🔬 Research Task"))
# Run agent with retry logic
result = run_agent_with_retry(
agent,
prompt,
{"configurable": {"thread_id": str(uuid.uuid4())}, "recursion_limit": 100}
)
success = True
except KeyboardInterrupt:
console.print("\n[yellow]Research interrupted by user[/yellow]")
success = False
except Exception as e:
console.print(f"\n[red]Error during research: {str(e)}[/red]")
success = False
# Gather results
return {
"facts": get_memory_value("key_facts"),
"files": list(get_related_files()),
"notes": get_memory_value("research_notes"),
"success": success
}
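
A hedged example of invoking the new tool directly and reading the result dict; the query string is a placeholder, and inside the chat agent this call is made by the model through the tool interface rather than by hand.

from ra_aid.tools.agent import request_research

result = request_research.invoke({"query": "How does the retry logic in agent_utils work?"})
# The dict mirrors the docstring above: research notes, key facts, related files, and a success flag.
print(result["success"])
print(result["notes"])
print(result["files"])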

View File

@ -84,6 +84,7 @@ def request_research_subtask(subtask: str) -> str:
"""Spawn a research subtask for investigation of a specific topic.
Use this anytime you can to offload your work to specific things that need to be looked into.
Use this only when it's necessary to dig deeper into a specific topic.
Args:
subtask: Detailed description of the research subtask