Allow chat agent to spawn research agent.
This commit is contained in:
parent
f712fcf9c8
commit
1fc64c5151
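
A minimal sketch of how the pieces fit together, using only functions touched by this commit; the thread id and the example prompt are illustrative, not part of the change itself:

    # Hedged sketch: a chat agent that can spawn a research agent via request_research.
    from langgraph.prebuilt import create_react_agent

    from ra_aid import run_agent_with_retry
    from ra_aid.llm import initialize_llm
    from ra_aid.tool_configs import get_chat_tools

    model = initialize_llm("anthropic", "claude-3-sonnet-20240229")

    # get_chat_tools now returns ask_human plus request_research, so the chat
    # agent delegates research by calling request_research, which creates and
    # runs its own research-only agent.
    chat_agent = create_react_agent(model, get_chat_tools(expert_enabled=True))

    run_agent_with_retry(
        chat_agent,
        "Investigate how retry handling works in this repository.",  # example prompt
        {"configurable": {"thread_id": "chat-demo"}, "recursion_limit": 100},
    )
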
@@ -2,6 +2,7 @@ from .__version__ import __version__
from .console.formatting import print_stage_header, print_task_header, print_error
from .console.output import print_agent_output
from .text.processing import truncate_output
from .agent_utils import run_agent_with_retry

__all__ = [
    'print_stage_header',
@@ -9,5 +10,6 @@ __all__ = [
    'print_agent_output',
    'truncate_output',
    'print_error',
    'run_agent_with_retry',
    '__version__'
]

@@ -1,15 +1,12 @@
import argparse
import sys
from typing import Optional
from rich.panel import Panel
from rich.markdown import Markdown
from rich.console import Console
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from ra_aid.env import validate_environment
from ra_aid.tools.memory import _global_memory, get_related_files, get_memory_value
from ra_aid import print_agent_output, print_stage_header, print_task_header, print_error
from ra_aid import print_stage_header, print_task_header, print_error, run_agent_with_retry
from ra_aid.prompts import (
    RESEARCH_PROMPT,
    PLANNING_PROMPT,
@@ -22,8 +19,6 @@ from ra_aid.prompts import (
    HUMAN_PROMPT_SECTION_PLANNING,
    HUMAN_PROMPT_SECTION_IMPLEMENTATION
)
import time
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
from ra_aid.llm import initialize_llm

from ra_aid.tool_configs import (
@@ -132,51 +127,6 @@ def is_stage_requested(stage: str) -> bool:
        return _global_memory.get('implementation_requested', False)
    return False

def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
    """Run an agent with retry logic for internal server errors and task completion handling.

    Args:
        agent: The agent to run
        prompt: The prompt to send to the agent
        config: Configuration dictionary for the agent

    Returns:
        Optional[str]: The completion message if task was completed, None otherwise

    Handles API errors with exponential backoff retry logic and checks for task
    completion after each chunk of output.
    """
    max_retries = 20
    base_delay = 1  # Initial delay in seconds

    for attempt in range(max_retries):
        try:
            for chunk in agent.stream(
                {"messages": [HumanMessage(content=prompt)]},
                config
            ):
                print_agent_output(chunk)

                # Check for task completion after each chunk
                if _global_memory.get('task_completed'):
                    completion_msg = _global_memory.get('completion_message', 'Task was completed successfully.')
                    console.print(Panel(
                        Markdown(completion_msg),
                        title="✅ Task Completed",
                        style="green"
                    ))
                    return completion_msg
            break
        except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
            if attempt == max_retries - 1:
                raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")

            delay = base_delay * (2 ** attempt)  # Exponential backoff
            error_type = e.__class__.__name__
            print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
            time.sleep(delay)
            continue

def run_implementation_stage(base_task, tasks, plan, related_files, model, expert_enabled: bool):
    """Run implementation stage with a distinct agent for each task."""
    if not is_stage_requested('implementation'):
@@ -216,7 +166,6 @@ def run_implementation_stage(base_task, tasks, plan, related_files, model, expert_enabled: bool):
        # Run agent for this task
        run_agent_with_retry(task_agent, task_prompt, {"configurable": {"thread_id": "abc123"}, "recursion_limit": 100})


def run_research_subtasks(base_task: str, config: dict, model, expert_enabled: bool):
    """Run research subtasks with separate agents."""
    subtasks = _global_memory.get('research_subtasks', [])

@@ -0,0 +1,69 @@
"""Utility functions for working with agents."""

import time
from typing import Optional

from langchain_core.messages import HumanMessage
from langchain_core.messages import BaseMessage
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel

from ra_aid.tools.memory import _global_memory

console = Console()

def print_agent_output(chunk: dict[str, BaseMessage]) -> None:
    """Print agent output chunks."""
    if chunk.get("delta") and chunk["delta"].content:
        console.print(chunk["delta"].content, end="", style="blue")

def print_error(msg: str) -> None:
    """Print error messages."""
    console.print(f"\n{msg}", style="red")

def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
    """Run an agent with retry logic for internal server errors and task completion handling.

    Args:
        agent: The agent to run
        prompt: The prompt to send to the agent
        config: Configuration dictionary for the agent

    Returns:
        Optional[str]: The completion message if task was completed, None otherwise

    Handles API errors with exponential backoff retry logic and checks for task
    completion after each chunk of output.
    """
    max_retries = 20
    base_delay = 1  # Initial delay in seconds

    for attempt in range(max_retries):
        try:
            for chunk in agent.stream(
                {"messages": [HumanMessage(content=prompt)]},
                config
            ):
                print_agent_output(chunk)

                # Check for task completion after each chunk
                if _global_memory.get('task_completed'):
                    completion_msg = _global_memory.get('completion_message', 'Task was completed successfully.')
                    console.print(Panel(
                        Markdown(completion_msg),
                        title="✅ Task Completed",
                        style="green"
                    ))
                    return completion_msg
            break
        except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
            if attempt == max_retries - 1:
                raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")

            delay = base_delay * (2 ** attempt)  # Exponential backoff
            error_type = e.__class__.__name__
            print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
            time.sleep(delay)
            continue

@@ -23,27 +23,23 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
        return ChatOpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            model=model_name,
            temperature=0
        )
    elif provider == "anthropic":
        return ChatAnthropic(
            api_key=os.getenv("ANTHROPIC_API_KEY"),
            model_name=model_name,
            temperature=0
        )
    elif provider == "openrouter":
        return ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model=model_name,
            temperature=0
        )
    elif provider == "openai-compatible":
        return ChatOpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_API_BASE"),
            model=model_name,
            temperature=0
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")

@@ -9,6 +9,7 @@ from ra_aid.tools import (
    swap_task_order, monorepo_detected, existing_project_detected, ui_detected
)
from ra_aid.tools.memory import one_shot_completed
from ra_aid.tools.agent import request_research

# Read-only tools that don't modify system state
def get_read_only_tools(human_interaction: bool = False) -> list:
@@ -61,6 +62,9 @@ def get_research_tools(research_only: bool = False, expert_enabled: bool = True,
    if expert_enabled:
        tools.extend(EXPERT_TOOLS)

    # Add chat-specific tools
    tools.append(request_research)

    return tools

def get_planning_tools(expert_enabled: bool = True) -> list:
@@ -103,15 +107,10 @@ def get_chat_tools(expert_enabled: bool = True) -> list:
    Chat mode includes research and implementation capabilities but excludes
    complex planning tools. Human interaction is always enabled.
    """
    # Start with read-only tools and always include human interaction
    tools = get_read_only_tools(human_interaction=True).copy()

    # Add implementation capability
    tools.extend(MODIFICATION_TOOLS)

    # Add research tools except for subtask management
    research_tools = [t for t in RESEARCH_TOOLS if t.name != 'request_research_subtask']
    tools.extend(research_tools)
    tools = [
        ask_human,
        request_research
    ]

    # Add expert tools if enabled
    if expert_enabled:

@@ -0,0 +1,95 @@
"""Tools for spawning and managing sub-agents."""

from langchain_core.tools import tool
from typing import Dict, Any, List, Optional
import uuid
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
from ra_aid.tools.memory import _global_memory
from ra_aid import run_agent_with_retry
from ..prompts import RESEARCH_PROMPT
from .memory import get_memory_value, get_related_files
from ..llm import initialize_llm

console = Console()

@tool("request_research")
def request_research(query: str) -> Dict[str, Any]:
    """Spawn a research-only agent to investigate the given query.

    Args:
        query: The research question or project description

    Returns:
        Dict containing:
            - notes: Research notes from the agent
            - facts: Current key facts
            - files: Related files
            - success: Whether completed or interrupted
    """
    # Initialize model and memory
    model = initialize_llm("anthropic", "claude-3-sonnet-20240229")
    memory = MemorySaver()
    memory.memory = _global_memory

    # Configure research tools
    from ..tool_configs import get_research_tools
    tools = get_research_tools(research_only=True, expert_enabled=True)

    # Basic config matching main process
    config = {
        "thread_id": str(uuid.uuid4()),
        "memory": memory,
        "model": model
    }

    from ra_aid.prompts import (
        RESEARCH_PROMPT,
        EXPERT_PROMPT_SECTION_RESEARCH,
        HUMAN_PROMPT_SECTION_RESEARCH
    )

    # Create research agent
    config = _global_memory.get('config', {})
    expert_enabled = config.get('expert_enabled', False)
    hil = config.get('hil', False)

    expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
    human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""

    agent = create_react_agent(model, tools)

    prompt = RESEARCH_PROMPT.format(
        base_task=query,
        research_only_note='',
        expert_section=expert_section,
        human_section=human_section
    )

    try:
        console.print(Panel(Markdown(query), title="🔬 Research Task"))
        # Run agent with retry logic
        result = run_agent_with_retry(
            agent,
            prompt,
            {"configurable": {"thread_id": str(uuid.uuid4())}, "recursion_limit": 100}
        )

        success = True
    except KeyboardInterrupt:
        console.print("\n[yellow]Research interrupted by user[/yellow]")
        success = False
    except Exception as e:
        console.print(f"\n[red]Error during research: {str(e)}[/red]")
        success = False

    # Gather results
    return {
        "facts": get_memory_value("key_facts"),
        "files": list(get_related_files()),
        "notes": get_memory_value("research_notes"),
        "success": success
    }

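Since request_research is a regular LangChain @tool that returns a plain dict, it can also be exercised on its own; a minimal, hypothetical invocation (the query string is just an example):

    # Hedged sketch: calling the new tool directly rather than through an agent.
    from ra_aid.tools.agent import request_research

    result = request_research.invoke(
        {"query": "Summarize how retries are handled in run_agent_with_retry"}
    )
    print(result["success"], result["notes"])
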
@@ -84,6 +84,7 @@ def request_research_subtask(subtask: str) -> str:
    """Spawn a research subtask for investigation of a specific topic.

    Use this anytime you can to offload your work to specific things that need to be looked into.
    Use this only when it's necessary to dig deeper into a specific topic.

    Args:
        subtask: Detailed description of the research subtask