Allow chat agent to spawn research agent.

parent f712fcf9c8
commit 1fc64c5151
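In outline: run_agent_with_retry (plus print_agent_output and print_error) moves out of the main module into a new ra_aid/agent_utils.py, a request_research tool in a new ra_aid/tools/agent.py spawns a standalone research agent, and that tool is wired into the chat and research tool sets. A minimal sketch of the resulting call flow, using only names that appear in the diff below; the chat-agent setup itself is illustrative, not code from this commit:

    from langgraph.prebuilt import create_react_agent
    from ra_aid import run_agent_with_retry
    from ra_aid.llm import initialize_llm
    from ra_aid.tool_configs import get_chat_tools

    # The chat agent now carries request_research in its tool list, so the model
    # can delegate investigation to a fresh research-only agent mid-conversation.
    model = initialize_llm("anthropic", "claude-3-sonnet-20240229")
    agent = create_react_agent(model, get_chat_tools(expert_enabled=True))
    run_agent_with_retry(
        agent,
        "Investigate how retries are handled, then summarize.",
        {"configurable": {"thread_id": "chat"}, "recursion_limit": 100},
    )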
ra_aid/__init__.py

@@ -2,6 +2,7 @@ from .__version__ import __version__
 from .console.formatting import print_stage_header, print_task_header, print_error
 from .console.output import print_agent_output
 from .text.processing import truncate_output
+from .agent_utils import run_agent_with_retry

 __all__ = [
     'print_stage_header',

@@ -9,5 +10,6 @@ __all__ = [
     'print_agent_output',
     'truncate_output',
     'print_error',
+    'run_agent_with_retry',
     '__version__'
 ]
ra_aid/__main__.py

@@ -1,15 +1,12 @@
 import argparse
 import sys
-from typing import Optional
 from rich.panel import Panel
-from rich.markdown import Markdown
 from rich.console import Console
-from langchain_core.messages import HumanMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
 from ra_aid.env import validate_environment
 from ra_aid.tools.memory import _global_memory, get_related_files, get_memory_value
-from ra_aid import print_agent_output, print_stage_header, print_task_header, print_error
+from ra_aid import print_stage_header, print_task_header, print_error, run_agent_with_retry
 from ra_aid.prompts import (
     RESEARCH_PROMPT,
     PLANNING_PROMPT,

@@ -22,8 +19,6 @@ from ra_aid.prompts import (
     HUMAN_PROMPT_SECTION_PLANNING,
     HUMAN_PROMPT_SECTION_IMPLEMENTATION
 )
-import time
-from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
 from ra_aid.llm import initialize_llm

 from ra_aid.tool_configs import (
@@ -132,51 +127,6 @@ def is_stage_requested(stage: str) -> bool:
         return _global_memory.get('implementation_requested', False)
     return False

-def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
-    """Run an agent with retry logic for internal server errors and task completion handling.
-
-    Args:
-        agent: The agent to run
-        prompt: The prompt to send to the agent
-        config: Configuration dictionary for the agent
-
-    Returns:
-        Optional[str]: The completion message if task was completed, None otherwise
-
-    Handles API errors with exponential backoff retry logic and checks for task
-    completion after each chunk of output.
-    """
-    max_retries = 20
-    base_delay = 1  # Initial delay in seconds
-
-    for attempt in range(max_retries):
-        try:
-            for chunk in agent.stream(
-                {"messages": [HumanMessage(content=prompt)]},
-                config
-            ):
-                print_agent_output(chunk)
-
-                # Check for task completion after each chunk
-                if _global_memory.get('task_completed'):
-                    completion_msg = _global_memory.get('completion_message', 'Task was completed successfully.')
-                    console.print(Panel(
-                        Markdown(completion_msg),
-                        title="✅ Task Completed",
-                        style="green"
-                    ))
-                    return completion_msg
-            break
-        except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
-            if attempt == max_retries - 1:
-                raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
-
-            delay = base_delay * (2 ** attempt)  # Exponential backoff
-            error_type = e.__class__.__name__
-            print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
-            time.sleep(delay)
-            continue
-
 def run_implementation_stage(base_task, tasks, plan, related_files, model, expert_enabled: bool):
     """Run implementation stage with a distinct agent for each task."""
     if not is_stage_requested('implementation'):

@@ -216,7 +166,6 @@ def run_implementation_stage(base_task, tasks, plan, related_files, model, expert_enabled: bool):
         # Run agent for this task
         run_agent_with_retry(task_agent, task_prompt, {"configurable": {"thread_id": "abc123"}, "recursion_limit": 100})

-
 def run_research_subtasks(base_task: str, config: dict, model, expert_enabled: bool):
     """Run research subtasks with separate agents."""
     subtasks = _global_memory.get('research_subtasks', [])
ra_aid/agent_utils.py (new file)

@@ -0,0 +1,69 @@
+"""Utility functions for working with agents."""
+
+import time
+from typing import Optional
+
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import BaseMessage
+from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
+from rich.console import Console
+from rich.markdown import Markdown
+from rich.panel import Panel
+
+from ra_aid.tools.memory import _global_memory
+
+console = Console()
+
+def print_agent_output(chunk: dict[str, BaseMessage]) -> None:
+    """Print agent output chunks."""
+    if chunk.get("delta") and chunk["delta"].content:
+        console.print(chunk["delta"].content, end="", style="blue")
+
+def print_error(msg: str) -> None:
+    """Print error messages."""
+    console.print(f"\n{msg}", style="red")
+
+def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
+    """Run an agent with retry logic for internal server errors and task completion handling.
+
+    Args:
+        agent: The agent to run
+        prompt: The prompt to send to the agent
+        config: Configuration dictionary for the agent
+
+    Returns:
+        Optional[str]: The completion message if task was completed, None otherwise
+
+    Handles API errors with exponential backoff retry logic and checks for task
+    completion after each chunk of output.
+    """
+    max_retries = 20
+    base_delay = 1  # Initial delay in seconds
+
+    for attempt in range(max_retries):
+        try:
+            for chunk in agent.stream(
+                {"messages": [HumanMessage(content=prompt)]},
+                config
+            ):
+                print_agent_output(chunk)
+
+                # Check for task completion after each chunk
+                if _global_memory.get('task_completed'):
+                    completion_msg = _global_memory.get('completion_message', 'Task was completed successfully.')
+                    console.print(Panel(
+                        Markdown(completion_msg),
+                        title="✅ Task Completed",
+                        style="green"
+                    ))
+                    return completion_msg
+            break
+        except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
+            if attempt == max_retries - 1:
+                raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
+
+            delay = base_delay * (2 ** attempt)  # Exponential backoff
+            error_type = e.__class__.__name__
+            print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
+            time.sleep(delay)
+            continue
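As a reading aid, a minimal standalone sketch (not part of the commit) of the backoff schedule run_agent_with_retry produces: the delay doubles after each failed attempt and is never capped, so a run that failed 19 times would wait 2**19 seconds (roughly six days) before the final attempt.

    # Reproduces the delay arithmetic from the function above, without sleeping.
    base_delay = 1
    max_retries = 20
    for attempt in range(max_retries):
        delay = base_delay * (2 ** attempt)  # 1, 2, 4, 8, ... seconds, uncapped
        print(f"attempt {attempt + 1}: retry after {delay}s")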
ra_aid/llm.py

@@ -23,27 +23,23 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             model=model_name,
-            temperature=0
         )
     elif provider == "anthropic":
         return ChatAnthropic(
             api_key=os.getenv("ANTHROPIC_API_KEY"),
             model_name=model_name,
-            temperature=0
         )
     elif provider == "openrouter":
         return ChatOpenAI(
             api_key=os.getenv("OPENROUTER_API_KEY"),
             base_url="https://openrouter.ai/api/v1",
             model=model_name,
-            temperature=0
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             base_url=os.getenv("OPENAI_API_BASE"),
             model=model_name,
-            temperature=0
         )
     else:
         raise ValueError(f"Unsupported provider: {provider}")
ra_aid/tool_configs.py

@@ -9,6 +9,7 @@ from ra_aid.tools import (
     swap_task_order, monorepo_detected, existing_project_detected, ui_detected
 )
 from ra_aid.tools.memory import one_shot_completed
+from ra_aid.tools.agent import request_research

 # Read-only tools that don't modify system state
 def get_read_only_tools(human_interaction: bool = False) -> list:

@@ -61,6 +62,9 @@ def get_research_tools(research_only: bool = False, expert_enabled: bool = True,
     if expert_enabled:
         tools.extend(EXPERT_TOOLS)

+    # Add chat-specific tools
+    tools.append(request_research)
+
     return tools

 def get_planning_tools(expert_enabled: bool = True) -> list:
@@ -103,18 +107,13 @@ def get_chat_tools(expert_enabled: bool = True) -> list:
     Chat mode includes research and implementation capabilities but excludes
     complex planning tools. Human interaction is always enabled.
     """
-    # Start with read-only tools and always include human interaction
-    tools = get_read_only_tools(human_interaction=True).copy()
-
-    # Add implementation capability
-    tools.extend(MODIFICATION_TOOLS)
-
-    # Add research tools except for subtask management
-    research_tools = [t for t in RESEARCH_TOOLS if t.name != 'request_research_subtask']
-    tools.extend(research_tools)
+    tools = [
+        ask_human,
+        request_research
+    ]

     # Add expert tools if enabled
     if expert_enabled:
         tools.extend(EXPERT_TOOLS)

     return tools
ra_aid/tools/agent.py (new file)

@@ -0,0 +1,95 @@
+"""Tools for spawning and managing sub-agents."""
+
+from langchain_core.tools import tool
+from typing import Dict, Any, List, Optional
+import uuid
+from rich.console import Console
+from rich.panel import Panel
+from rich.markdown import Markdown
+from langgraph.prebuilt import create_react_agent
+from langgraph.checkpoint.memory import MemorySaver
+from ra_aid.tools.memory import _global_memory
+from ra_aid import run_agent_with_retry
+from ..prompts import RESEARCH_PROMPT
+from .memory import get_memory_value, get_related_files
+from ..llm import initialize_llm
+
+console = Console()
+
+@tool("request_research")
+def request_research(query: str) -> Dict[str, Any]:
+    """Spawn a research-only agent to investigate the given query.
+
+    Args:
+        query: The research question or project description
+
+    Returns:
+        Dict containing:
+            - notes: Research notes from the agent
+            - facts: Current key facts
+            - files: Related files
+            - success: Whether completed or interrupted
+    """
+    # Initialize model and memory
+    model = initialize_llm("anthropic", "claude-3-sonnet-20240229")
+    memory = MemorySaver()
+    memory.memory = _global_memory
+
+    # Configure research tools
+    from ..tool_configs import get_research_tools
+    tools = get_research_tools(research_only=True, expert_enabled=True)
+
+    # Basic config matching main process
+    config = {
+        "thread_id": str(uuid.uuid4()),
+        "memory": memory,
+        "model": model
+    }
+
+    from ra_aid.prompts import (
+        RESEARCH_PROMPT,
+        EXPERT_PROMPT_SECTION_RESEARCH,
+        HUMAN_PROMPT_SECTION_RESEARCH
+    )
+
+    # Create research agent
+    config = _global_memory.get('config', {})
+    expert_enabled = config.get('expert_enabled', False)
+    hil = config.get('hil', False)
+
+    expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
+    human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
+
+    agent = create_react_agent(model, tools)
+
+    prompt = RESEARCH_PROMPT.format(
+        base_task=query,
+        research_only_note='',
+        expert_section=expert_section,
+        human_section=human_section
+    )
+
+    try:
+        console.print(Panel(Markdown(query), title="🔬 Research Task"))
+        # Run agent with retry logic
+        result = run_agent_with_retry(
+            agent,
+            prompt,
+            {"configurable": {"thread_id": str(uuid.uuid4())}, "recursion_limit": 100}
+        )
+
+        success = True
+    except KeyboardInterrupt:
+        console.print("\n[yellow]Research interrupted by user[/yellow]")
+        success = False
+    except Exception as e:
+        console.print(f"\n[red]Error during research: {str(e)}[/red]")
+        success = False
+
+    # Gather results
+    return {
+        "facts": get_memory_value("key_facts"),
+        "files": list(get_related_files()),
+        "notes": get_memory_value("research_notes"),
+        "success": success
+    }
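For context, a minimal sketch of exercising the new tool directly; the .invoke call is standard LangChain usage for @tool-decorated functions, and the query string is hypothetical. In chat mode the agent triggers this itself, since get_chat_tools now includes request_research.

    from ra_aid.tools.agent import request_research

    # LangChain tools are invoked with a dict of their declared arguments.
    result = request_research.invoke({"query": "How does this codebase retry failed API calls?"})
    print(result["success"])
    print(result["notes"])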
ra_aid/tools/memory.py

@@ -84,6 +84,7 @@ def request_research_subtask(subtask: str) -> str:
     """Spawn a research subtask for investigation of a specific topic.

     Use this anytime you can to offload your work to specific things that need to be looked into.
+    Use this only when it's necessary to dig deeper into a specific topic.

     Args:
         subtask: Detailed description of the research subtask