Allow chat agent to spawn research agent.

AI Christianson 2024-12-21 10:58:08 -05:00
parent f712fcf9c8
commit 1fc64c5151
7 changed files with 177 additions and 66 deletions

View File

@@ -2,6 +2,7 @@ from .__version__ import __version__
 from .console.formatting import print_stage_header, print_task_header, print_error
 from .console.output import print_agent_output
 from .text.processing import truncate_output
+from .agent_utils import run_agent_with_retry

 __all__ = [
@@ -9,5 +10,6 @@ __all__ = [
     'print_agent_output',
     'truncate_output',
     'print_error',
+    'run_agent_with_retry',
     '__version__'
 ]
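
The re-export above puts run_agent_with_retry on the package's top level, which is how the rest of this commit imports it (see tools/agent.py below); the usage is simply:

from ra_aid import run_agent_with_retry  # rather than ra_aid.agent_utils.run_agent_with_retry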

View File

@@ -1,15 +1,12 @@
 import argparse
 import sys
-from typing import Optional
 from rich.panel import Panel
-from rich.markdown import Markdown
 from rich.console import Console
-from langchain_core.messages import HumanMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
 from ra_aid.env import validate_environment
 from ra_aid.tools.memory import _global_memory, get_related_files, get_memory_value
-from ra_aid import print_agent_output, print_stage_header, print_task_header, print_error
+from ra_aid import print_stage_header, print_task_header, print_error, run_agent_with_retry
 from ra_aid.prompts import (
     RESEARCH_PROMPT,
     PLANNING_PROMPT,
@@ -22,8 +19,6 @@ from ra_aid.prompts import (
     HUMAN_PROMPT_SECTION_PLANNING,
     HUMAN_PROMPT_SECTION_IMPLEMENTATION
 )
-import time
-from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
 from ra_aid.llm import initialize_llm
 from ra_aid.tool_configs import (
@@ -132,51 +127,6 @@ def is_stage_requested(stage: str) -> bool:
         return _global_memory.get('implementation_requested', False)
     return False
-
-def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
-    """Run an agent with retry logic for internal server errors and task completion handling.
-
-    Args:
-        agent: The agent to run
-        prompt: The prompt to send to the agent
-        config: Configuration dictionary for the agent
-
-    Returns:
-        Optional[str]: The completion message if task was completed, None otherwise
-
-    Handles API errors with exponential backoff retry logic and checks for task
-    completion after each chunk of output.
-    """
-    max_retries = 20
-    base_delay = 1  # Initial delay in seconds
-
-    for attempt in range(max_retries):
-        try:
-            for chunk in agent.stream(
-                {"messages": [HumanMessage(content=prompt)]},
-                config
-            ):
-                print_agent_output(chunk)
-
-                # Check for task completion after each chunk
-                if _global_memory.get('task_completed'):
-                    completion_msg = _global_memory.get('completion_message', 'Task was completed successfully.')
-                    console.print(Panel(
-                        Markdown(completion_msg),
-                        title="✅ Task Completed",
-                        style="green"
-                    ))
-                    return completion_msg
-            break
-        except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
-            if attempt == max_retries - 1:
-                raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
-            delay = base_delay * (2 ** attempt)  # Exponential backoff
-            error_type = e.__class__.__name__
-            print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
-            time.sleep(delay)
-            continue

 def run_implementation_stage(base_task, tasks, plan, related_files, model, expert_enabled: bool):
     """Run implementation stage with a distinct agent for each task."""
     if not is_stage_requested('implementation'):
@@ -216,7 +166,6 @@ def run_implementation_stage(base_task, tasks, plan, related_files, model, expert_enabled: bool):
         # Run agent for this task
         run_agent_with_retry(task_agent, task_prompt, {"configurable": {"thread_id": "abc123"}, "recursion_limit": 100})
-

 def run_research_subtasks(base_task: str, config: dict, model, expert_enabled: bool):
     """Run research subtasks with separate agents."""
     subtasks = _global_memory.get('research_subtasks', [])
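
For context, run_agent_with_retry (now imported from the ra_aid package) returns the completion message only when the sub-agent signals completion, so a caller can distinguish a completed task from a stream that simply ended; a hedged sketch of the pattern (variable names are illustrative, config values copied from the call above):

completion = run_agent_with_retry(
    task_agent,
    task_prompt,
    {"configurable": {"thread_id": "abc123"}, "recursion_limit": 100}
)
if completion is None:
    # The agent finished streaming without setting task_completed in _global_memory
    print_error("Task ended without an explicit completion signal.")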

ra_aid/agent_utils.py Normal file
View File

@@ -0,0 +1,69 @@
+"""Utility functions for working with agents."""
+
+import time
+from typing import Optional
+
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import BaseMessage
+from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
+from rich.console import Console
+from rich.markdown import Markdown
+from rich.panel import Panel
+
+from ra_aid.tools.memory import _global_memory
+
+console = Console()
+
+
+def print_agent_output(chunk: dict[str, BaseMessage]) -> None:
+    """Print agent output chunks."""
+    if chunk.get("delta") and chunk["delta"].content:
+        console.print(chunk["delta"].content, end="", style="blue")
+
+
+def print_error(msg: str) -> None:
+    """Print error messages."""
+    console.print(f"\n{msg}", style="red")
+
+
+def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
+    """Run an agent with retry logic for internal server errors and task completion handling.
+
+    Args:
+        agent: The agent to run
+        prompt: The prompt to send to the agent
+        config: Configuration dictionary for the agent
+
+    Returns:
+        Optional[str]: The completion message if task was completed, None otherwise
+
+    Handles API errors with exponential backoff retry logic and checks for task
+    completion after each chunk of output.
+    """
+    max_retries = 20
+    base_delay = 1  # Initial delay in seconds
+
+    for attempt in range(max_retries):
+        try:
+            for chunk in agent.stream(
+                {"messages": [HumanMessage(content=prompt)]},
+                config
+            ):
+                print_agent_output(chunk)
+
+                # Check for task completion after each chunk
+                if _global_memory.get('task_completed'):
+                    completion_msg = _global_memory.get('completion_message', 'Task was completed successfully.')
+                    console.print(Panel(
+                        Markdown(completion_msg),
+                        title="✅ Task Completed",
+                        style="green"
+                    ))
+                    return completion_msg
+            break
+        except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
+            if attempt == max_retries - 1:
+                raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
+            delay = base_delay * (2 ** attempt)  # Exponential backoff
+            error_type = e.__class__.__name__
+            print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
+            time.sleep(delay)
+            continue
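
As a quick illustration of the retry cadence above (not part of the commit): with base_delay = 1, the wait doubles after each failed attempt.

base_delay = 1
for attempt in range(5):
    delay = base_delay * (2 ** attempt)  # same formula as run_agent_with_retry
    print(f"attempt {attempt + 1}: retry after {delay}s")
# attempt 1: retry after 1s
# attempt 2: retry after 2s
# attempt 3: retry after 4s
# attempt 4: retry after 8s
# attempt 5: retry after 16s

With max_retries = 20, the worst-case final delay is 2 ** 19 seconds (roughly six days), so capping the delay is a plausible future refinement.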

View File

@@ -23,27 +23,23 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             model=model_name,
-            temperature=0
         )
     elif provider == "anthropic":
         return ChatAnthropic(
             api_key=os.getenv("ANTHROPIC_API_KEY"),
             model_name=model_name,
-            temperature=0
         )
     elif provider == "openrouter":
         return ChatOpenAI(
             api_key=os.getenv("OPENROUTER_API_KEY"),
             base_url="https://openrouter.ai/api/v1",
             model=model_name,
-            temperature=0
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             base_url=os.getenv("OPENAI_API_BASE"),
             model=model_name,
-            temperature=0
         )
     else:
         raise ValueError(f"Unsupported provider: {provider}")
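
Dropping temperature=0 means each client now runs at its provider's default sampling temperature. Both LangChain chat classes still accept the parameter, so determinism can be restored per call site; a minimal sketch, assuming the usual langchain_anthropic import and reusing the model name hardcoded elsewhere in this commit:

import os
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(
    api_key=os.getenv("ANTHROPIC_API_KEY"),
    model_name="claude-3-sonnet-20240229",
    temperature=0,  # opt back in to low-variance sampling where it matters
)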

View File

@@ -9,6 +9,7 @@ from ra_aid.tools import (
     swap_task_order, monorepo_detected, existing_project_detected, ui_detected
 )
 from ra_aid.tools.memory import one_shot_completed
+from ra_aid.tools.agent import request_research

 # Read-only tools that don't modify system state
 def get_read_only_tools(human_interaction: bool = False) -> list:
@@ -61,6 +62,9 @@ def get_research_tools(research_only: bool = False, expert_enabled: bool = True,
     if expert_enabled:
         tools.extend(EXPERT_TOOLS)
+
+    # Add chat-specific tools
+    tools.append(request_research)

     return tools

 def get_planning_tools(expert_enabled: bool = True) -> list:
@@ -103,18 +107,13 @@ def get_chat_tools(expert_enabled: bool = True) -> list:
     Chat mode includes research and implementation capabilities but excludes
     complex planning tools. Human interaction is always enabled.
     """
-    # Start with read-only tools and always include human interaction
-    tools = get_read_only_tools(human_interaction=True).copy()
-
-    # Add implementation capability
-    tools.extend(MODIFICATION_TOOLS)
-
-    # Add research tools except for subtask management
-    research_tools = [t for t in RESEARCH_TOOLS if t.name != 'request_research_subtask']
-    tools.extend(research_tools)
+    tools = [
+        ask_human,
+        request_research
+    ]

     # Add expert tools if enabled
     if expert_enabled:
         tools.extend(EXPERT_TOOLS)

     return tools
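
get_chat_tools is now deliberately narrow: the chat agent converses via ask_human and delegates investigation to a spawned researcher via request_research (plus expert tools when enabled). A hedged sketch of how a chat agent could be assembled from this config, mirroring the create_react_agent usage elsewhere in the diff (provider and model are illustrative):

from langgraph.prebuilt import create_react_agent
from ra_aid.llm import initialize_llm
from ra_aid.tool_configs import get_chat_tools

model = initialize_llm("anthropic", "claude-3-sonnet-20240229")
chat_agent = create_react_agent(model, get_chat_tools(expert_enabled=True))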

ra_aid/tools/agent.py Normal file
View File

@@ -0,0 +1,95 @@
+"""Tools for spawning and managing sub-agents."""
+
+from langchain_core.tools import tool
+from typing import Dict, Any, List, Optional
+import uuid
+
+from rich.console import Console
+from rich.panel import Panel
+from rich.markdown import Markdown
+from langgraph.prebuilt import create_react_agent
+from langgraph.checkpoint.memory import MemorySaver
+
+from ra_aid.tools.memory import _global_memory
+from ra_aid import run_agent_with_retry
+from ..prompts import RESEARCH_PROMPT
+from .memory import get_memory_value, get_related_files
+from ..llm import initialize_llm
+
+console = Console()
+
+
+@tool("request_research")
+def request_research(query: str) -> Dict[str, Any]:
+    """Spawn a research-only agent to investigate the given query.
+
+    Args:
+        query: The research question or project description
+
+    Returns:
+        Dict containing:
+            - notes: Research notes from the agent
+            - facts: Current key facts
+            - files: Related files
+            - success: Whether completed or interrupted
+    """
+    # Initialize model and memory
+    model = initialize_llm("anthropic", "claude-3-sonnet-20240229")
+    memory = MemorySaver()
+    memory.memory = _global_memory
+
+    # Configure research tools
+    from ..tool_configs import get_research_tools
+    tools = get_research_tools(research_only=True, expert_enabled=True)
+
+    # Basic config matching main process
+    config = {
+        "thread_id": str(uuid.uuid4()),
+        "memory": memory,
+        "model": model
+    }
+
+    from ra_aid.prompts import (
+        RESEARCH_PROMPT,
+        EXPERT_PROMPT_SECTION_RESEARCH,
+        HUMAN_PROMPT_SECTION_RESEARCH
+    )
+
+    # Create research agent
+    config = _global_memory.get('config', {})
+    expert_enabled = config.get('expert_enabled', False)
+    hil = config.get('hil', False)
+    expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
+    human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
+    agent = create_react_agent(model, tools)
+    prompt = RESEARCH_PROMPT.format(
+        base_task=query,
+        research_only_note='',
+        expert_section=expert_section,
+        human_section=human_section
+    )
+
+    try:
+        console.print(Panel(Markdown(query), title="🔬 Research Task"))
+        # Run agent with retry logic
+        result = run_agent_with_retry(
+            agent,
+            prompt,
+            {"configurable": {"thread_id": str(uuid.uuid4())}, "recursion_limit": 100}
+        )
+        success = True
+    except KeyboardInterrupt:
+        console.print("\n[yellow]Research interrupted by user[/yellow]")
+        success = False
+    except Exception as e:
+        console.print(f"\n[red]Error during research: {str(e)}[/red]")
+        success = False
+
+    # Gather results
+    return {
+        "facts": get_memory_value("key_facts"),
+        "files": list(get_related_files()),
+        "notes": get_memory_value("research_notes"),
+        "success": success
+    }
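
Because request_research is wrapped with @tool, the chat agent invokes it like any other tool, and it can also be exercised directly through the standard LangChain tool interface; an illustrative sketch (the query string is made up):

result = request_research.invoke({"query": "Map out how retry logic is shared between agents"})
if result["success"]:
    print(result["notes"])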

View File

@@ -84,6 +84,7 @@ def request_research_subtask(subtask: str) -> str:
     """Spawn a research subtask for investigation of a specific topic.
     Use this anytime you can to offload your work to specific things that need to be looked into.
+    Use this only when it's necessary to dig deeper into a specific topic.

     Args:
         subtask: Detailed description of the research subtask