diff --git a/ra_aid/__init__.py b/ra_aid/__init__.py
index 20dd039..06aff82 100644
--- a/ra_aid/__init__.py
+++ b/ra_aid/__init__.py
@@ -2,6 +2,7 @@ from .__version__ import __version__
 from .console.formatting import print_stage_header, print_task_header, print_error
 from .console.output import print_agent_output
 from .text.processing import truncate_output
+from .agent_utils import run_agent_with_retry
 
 __all__ = [
     'print_stage_header',
@@ -9,5 +10,6 @@ __all__ = [
     'print_agent_output',
     'truncate_output',
     'print_error',
+    'run_agent_with_retry',
     '__version__'
 ]
diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index 19196cd..38f115c 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -1,15 +1,12 @@
 import argparse
 import sys
-from typing import Optional
 from rich.panel import Panel
-from rich.markdown import Markdown
 from rich.console import Console
-from langchain_core.messages import HumanMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.prebuilt import create_react_agent
 from ra_aid.env import validate_environment
 from ra_aid.tools.memory import _global_memory, get_related_files, get_memory_value
-from ra_aid import print_agent_output, print_stage_header, print_task_header, print_error
+from ra_aid import print_stage_header, print_task_header, print_error, run_agent_with_retry
 from ra_aid.prompts import (
     RESEARCH_PROMPT,
     PLANNING_PROMPT,
@@ -22,8 +19,6 @@ from ra_aid.prompts import (
     HUMAN_PROMPT_SECTION_PLANNING,
     HUMAN_PROMPT_SECTION_IMPLEMENTATION
 )
-import time
-from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
 
 from ra_aid.llm import initialize_llm
 from ra_aid.tool_configs import (
@@ -132,51 +127,6 @@ def is_stage_requested(stage: str) -> bool:
         return _global_memory.get('implementation_requested', False)
     return False
 
-def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
-    """Run an agent with retry logic for internal server errors and task completion handling.
-
-    Args:
-        agent: The agent to run
-        prompt: The prompt to send to the agent
-        config: Configuration dictionary for the agent
-
-    Returns:
-        Optional[str]: The completion message if task was completed, None otherwise
-
-    Handles API errors with exponential backoff retry logic and checks for task
-    completion after each chunk of output.
-    """
-    max_retries = 20
-    base_delay = 1  # Initial delay in seconds
-
-    for attempt in range(max_retries):
-        try:
-            for chunk in agent.stream(
-                {"messages": [HumanMessage(content=prompt)]},
-                config
-            ):
-                print_agent_output(chunk)
-
-                # Check for task completion after each chunk
-                if _global_memory.get('task_completed'):
-                    completion_msg = _global_memory.get('completion_message', 'Task was completed successfully.')
-                    console.print(Panel(
-                        Markdown(completion_msg),
-                        title="✅ Task Completed",
-                        style="green"
-                    ))
-                    return completion_msg
-            break
-        except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
-            if attempt == max_retries - 1:
-                raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
-
-            delay = base_delay * (2 ** attempt)  # Exponential backoff
-            error_type = e.__class__.__name__
-            print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
-            time.sleep(delay)
-            continue
-
 def run_implementation_stage(base_task, tasks, plan, related_files, model, expert_enabled: bool):
     """Run implementation stage with a distinct agent for each task."""
     if not is_stage_requested('implementation'):
@@ -216,7 +166,6 @@ def run_implementation_stage(base_task, tasks, plan, related_files, model, exper
         # Run agent for this task
         run_agent_with_retry(task_agent, task_prompt, {"configurable": {"thread_id": "abc123"}, "recursion_limit": 100})
 
-
 def run_research_subtasks(base_task: str, config: dict, model, expert_enabled: bool):
     """Run research subtasks with separate agents."""
     subtasks = _global_memory.get('research_subtasks', [])
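Since the retry helper now lives in `ra_aid.agent_utils` and is re-exported from the package root, callers keep a stable import path. A minimal sketch of the call shape, assuming an installed `ra_aid`; the stub agent below is illustrative only, real callers pass a LangGraph agent built via `create_react_agent`:

```python
# Hypothetical caller sketch; StubAgent stands in for a real LangGraph agent.
from ra_aid import run_agent_with_retry

class StubAgent:
    def stream(self, inputs, config):
        yield {}  # a real agent yields chunks of agent output here

# Returns the completion message if the task finished, otherwise None.
result = run_agent_with_retry(
    StubAgent(),
    "example prompt",
    {"configurable": {"thread_id": "demo"}, "recursion_limit": 100},
)
```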
diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py
new file mode 100644
index 0000000..9a97491
--- /dev/null
+++ b/ra_aid/agent_utils.py
@@ -0,0 +1,69 @@
+"""Utility functions for working with agents."""
+
+import time
+from typing import Optional
+
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import BaseMessage
+from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
+from rich.console import Console
+from rich.markdown import Markdown
+from rich.panel import Panel
+
+from ra_aid.tools.memory import _global_memory
+
+console = Console()
+
+def print_agent_output(chunk: dict[str, BaseMessage]) -> None:
+    """Print agent output chunks."""
+    if chunk.get("delta") and chunk["delta"].content:
+        console.print(chunk["delta"].content, end="", style="blue")
+
+def print_error(msg: str) -> None:
+    """Print error messages."""
+    console.print(f"\n{msg}", style="red")
+
+def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
+    """Run an agent with retry logic for internal server errors and task completion handling.
+
+    Args:
+        agent: The agent to run
+        prompt: The prompt to send to the agent
+        config: Configuration dictionary for the agent
+
+    Returns:
+        Optional[str]: The completion message if task was completed, None otherwise
+
+    Handles API errors with exponential backoff retry logic and checks for task
+    completion after each chunk of output.
+    """
+    max_retries = 20
+    base_delay = 1  # Initial delay in seconds
+
+    for attempt in range(max_retries):
+        try:
+            for chunk in agent.stream(
+                {"messages": [HumanMessage(content=prompt)]},
+                config
+            ):
+                print_agent_output(chunk)
+
+                # Check for task completion after each chunk
+                if _global_memory.get('task_completed'):
+                    completion_msg = _global_memory.get('completion_message', 'Task was completed successfully.')
+                    console.print(Panel(
+                        Markdown(completion_msg),
+                        title="✅ Task Completed",
+                        style="green"
+                    ))
+                    return completion_msg
+            break
+        except (InternalServerError, APITimeoutError, RateLimitError, APIError) as e:
+            if attempt == max_retries - 1:
+                raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {str(e)}")
+
+            delay = base_delay * (2 ** attempt)  # Exponential backoff
+            error_type = e.__class__.__name__
+            print_error(f"Encountered {error_type}: {str(e)}. Retrying in {delay} seconds... (Attempt {attempt + 1}/{max_retries})")
+            time.sleep(delay)
+            continue
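The retry loop above doubles its wait after every failed attempt. A standalone sketch of the schedule produced by the constants in this file:

```python
# Exponential backoff: base_delay * 2**attempt seconds before the next try.
base_delay = 1
max_retries = 20

delays = [base_delay * (2 ** attempt) for attempt in range(max_retries - 1)]
print(delays[:6])  # [1, 2, 4, 8, 16, 32] -- waits before attempts 2 through 7
```

Note that the later waits grow past 2**18 seconds; capping the delay (for example `min(delay, 60)`) may be worth considering.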
diff --git a/ra_aid/llm.py b/ra_aid/llm.py
index b6438cf..35d0376 100644
--- a/ra_aid/llm.py
+++ b/ra_aid/llm.py
@@ -23,27 +23,23 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             model=model_name,
-            temperature=0
         )
     elif provider == "anthropic":
         return ChatAnthropic(
             api_key=os.getenv("ANTHROPIC_API_KEY"),
             model_name=model_name,
-            temperature=0
         )
     elif provider == "openrouter":
         return ChatOpenAI(
             api_key=os.getenv("OPENROUTER_API_KEY"),
             base_url="https://openrouter.ai/api/v1",
             model=model_name,
-            temperature=0
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             base_url=os.getenv("OPENAI_API_BASE"),
             model=model_name,
-            temperature=0
         )
     else:
         raise ValueError(f"Unsupported provider: {provider}")
diff --git a/ra_aid/tool_configs.py b/ra_aid/tool_configs.py
index 825e644..a2dfc15 100644
--- a/ra_aid/tool_configs.py
+++ b/ra_aid/tool_configs.py
@@ -9,6 +9,7 @@ from ra_aid.tools import (
     swap_task_order, monorepo_detected, existing_project_detected, ui_detected
 )
 from ra_aid.tools.memory import one_shot_completed
+from ra_aid.tools.agent import request_research
 
 # Read-only tools that don't modify system state
def get_read_only_tools(human_interaction: bool = False) -> list:
@@ -61,6 +62,9 @@ def get_research_tools(research_only: bool = False, expert_enabled: bool = True,
     if expert_enabled:
         tools.extend(EXPERT_TOOLS)
 
+    # Let research agents spawn nested research agents
+    tools.append(request_research)
+
     return tools
 
 def get_planning_tools(expert_enabled: bool = True) -> list:
@@ -103,18 +107,13 @@ def get_chat_tools(expert_enabled: bool = True) -> list:
-    Chat mode includes research and implementation capabilities but excludes
-    complex planning tools. Human interaction is always enabled.
+    Chat mode delegates all work through human interaction and research
+    requests instead of exposing planning and modification tools directly.
     """
-    # Start with read-only tools and always include human interaction
-    tools = get_read_only_tools(human_interaction=True).copy()
-
-    # Add implementation capability
-    tools.extend(MODIFICATION_TOOLS)
-
-    # Add research tools except for subtask management
-    research_tools = [t for t in RESEARCH_TOOLS if t.name != 'request_research_subtask']
-    tools.extend(research_tools)
+    tools = [
+        ask_human,
+        request_research
+    ]
 
     # Add expert tools if enabled
     if expert_enabled:
         tools.extend(EXPERT_TOOLS)
-    
+
     return tools
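On the llm.py hunk above: with the hard-coded `temperature=0` removed, each client library's default sampling temperature now applies. A hedged check, assuming `ra_aid` is installed and an `OPENAI_API_KEY` is set (0.7 is the `langchain_openai` default at the time of writing):

```python
from ra_aid.llm import initialize_llm

# No longer forced to 0: ChatOpenAI falls back to its own default (0.7),
# and ChatAnthropic likewise uses the provider default when unset.
model = initialize_llm("openai", "gpt-4")
print(model.temperature)
```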
diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py
new file mode 100644
index 0000000..1640738
--- /dev/null
+++ b/ra_aid/tools/agent.py
@@ -0,0 +1,80 @@
+"""Tools for spawning and managing sub-agents."""
+
+from langchain_core.tools import tool
+from typing import Dict, Any
+import uuid
+from rich.console import Console
+from rich.panel import Panel
+from rich.markdown import Markdown
+from langgraph.prebuilt import create_react_agent
+from ra_aid.agent_utils import run_agent_with_retry
+from ..prompts import (
+    RESEARCH_PROMPT,
+    EXPERT_PROMPT_SECTION_RESEARCH,
+    HUMAN_PROMPT_SECTION_RESEARCH
+)
+from .memory import _global_memory, get_memory_value, get_related_files
+from ..llm import initialize_llm
+
+console = Console()
+
+@tool("request_research")
+def request_research(query: str) -> Dict[str, Any]:
+    """Spawn a research-only agent to investigate the given query.
+
+    Args:
+        query: The research question or project description
+
+    Returns:
+        Dict containing:
+            - notes: Research notes from the agent
+            - facts: Current key facts
+            - files: Related files
+            - success: Whether completed or interrupted
+    """
+    model = initialize_llm("anthropic", "claude-3-sonnet-20240229")
+
+    # Imported here to avoid a circular import: tool_configs imports this module.
+    from ..tool_configs import get_research_tools
+    tools = get_research_tools(research_only=True, expert_enabled=True)
+
+    # Mirror the main process configuration stored in global memory
+    config = _global_memory.get('config', {})
+    expert_enabled = config.get('expert_enabled', False)
+    hil = config.get('hil', False)
+
+    expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
+    human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
+
+    agent = create_react_agent(model, tools)
+
+    prompt = RESEARCH_PROMPT.format(
+        base_task=query,
+        research_only_note='',
+        expert_section=expert_section,
+        human_section=human_section
+    )
+
+    try:
+        console.print(Panel(Markdown(query), title="🔬 Research Task"))
+        # Run agent with retry logic
+        run_agent_with_retry(
+            agent,
+            prompt,
+            {"configurable": {"thread_id": str(uuid.uuid4())}, "recursion_limit": 100}
+        )
+        success = True
+    except KeyboardInterrupt:
+        console.print("\n[yellow]Research interrupted by user[/yellow]")
+        success = False
+    except Exception as e:
+        console.print(f"\n[red]Error during research: {str(e)}[/red]")
+        success = False
+
+    # Gather results the sub-agent accumulated in global memory
+    return {
+        "facts": get_memory_value("key_facts"),
+        "files": list(get_related_files()),
+        "notes": get_memory_value("research_notes"),
+        "success": success
+    }
diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py
index 82a24d3..454f152 100644
--- a/ra_aid/tools/memory.py
+++ b/ra_aid/tools/memory.py
@@ -84,6 +84,6 @@ def request_research_subtask(subtask: str) -> str:
     """Spawn a research subtask for investigation of a specific topic.
 
-    Use this anytime you can to offload your work to specific things that need to be looked into.
+    Use this only when it's necessary to dig deeper into a specific topic.
 
     Args:
         subtask: Detailed description of the research subtask
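Finally, a hedged sketch of how the new `request_research` tool might be driven; LangChain tools take dict input via `.invoke`, and the query string here is purely illustrative:

```python
from ra_aid.tools.agent import request_research

# The tool returns a plain dict, so callers can still inspect partial
# results when the sub-agent was interrupted.
result = request_research.invoke({"query": "How does the retry logic handle rate limits?"})
if result["success"]:
    print(result["notes"])
else:
    print("Interrupted; partial facts:", result["facts"])
```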