diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index a8797e8..60874a9 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -7,7 +7,8 @@ from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt import create_react_agent from ra_aid.env import validate_environment from ra_aid.tools.memory import _global_memory, get_related_files, get_memory_value -from ra_aid import print_stage_header, print_task_header, print_error, run_agent_with_retry +from ra_aid import print_stage_header, print_task_header, print_error +from ra_aid.agent_utils import run_agent_with_retry, run_task_implementation_agent from ra_aid.agent_utils import run_research_agent from ra_aid.prompts import ( PLANNING_PROMPT, @@ -140,28 +141,16 @@ def run_implementation_stage(base_task, tasks, plan, related_files, model, exper for i, task in enumerate(task_list, 1): print_task_header(task) - # Create a unique memory instance for this task - task_memory = MemorySaver() - - # Create a fresh agent for each task - task_agent = create_react_agent(model, get_implementation_tools(expert_enabled=expert_enabled), checkpointer=task_memory) - - # Construct task-specific prompt - expert_section = EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "" - human_section = HUMAN_PROMPT_SECTION_IMPLEMENTATION if _global_memory.get('config', {}).get('hil', False) else "" - task_prompt = (IMPLEMENTATION_PROMPT).format( - plan=plan, - key_facts=get_memory_value('key_facts'), - key_snippets=get_memory_value('key_snippets'), - task=task, - related_files="\n".join(related_files), + # Run implementation agent for this task + run_task_implementation_agent( base_task=base_task, - expert_section=expert_section, - human_section=human_section + tasks=task_list, + task=task, + plan=plan, + related_files=related_files, + model=model, + expert_enabled=expert_enabled ) - - # Run agent for this task - run_agent_with_retry(task_agent, task_prompt, {"configurable": {"thread_id": "abc123"}, "recursion_limit": 100}) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index d43423d..6be9f99 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -2,9 +2,18 @@ import time import uuid -from typing import Optional, Any +from typing import Optional, Any, List from langgraph.prebuilt import create_react_agent +from ra_aid.tool_configs import get_implementation_tools, get_research_tools +from ra_aid.prompts import ( + IMPLEMENTATION_PROMPT, + EXPERT_PROMPT_SECTION_IMPLEMENTATION, + HUMAN_PROMPT_SECTION_IMPLEMENTATION, + EXPERT_PROMPT_SECTION_RESEARCH, + RESEARCH_PROMPT, + HUMAN_PROMPT_SECTION_RESEARCH +) from langgraph.checkpoint.memory import MemorySaver from langchain_core.messages import HumanMessage @@ -14,7 +23,10 @@ from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel -from ra_aid.tools.memory import _global_memory +from ra_aid.tools.memory import ( + _global_memory, + get_memory_value, +) from ra_aid.globals import RESEARCH_AGENT_RECURSION_LIMIT from ra_aid.tool_configs import get_research_tools from ra_aid.prompts import ( @@ -64,7 +76,6 @@ def run_research_agent( # Initialize memory if not provided if memory is None: memory = MemorySaver() - memory.memory = _global_memory # Set up thread ID if thread_id is None: @@ -116,6 +127,73 @@ def print_error(msg: str) -> None: """Print error messages.""" console.print(f"\n{msg}", style="red") +def run_task_implementation_agent( + base_task: str, + tasks: list, + task: str, + plan: str, + related_files: list, + model, + *, + expert_enabled: bool = False, + memory: Optional[Any] = None, + config: Optional[dict] = None, + thread_id: Optional[str] = None +) -> Optional[str]: + """Run an implementation agent for a specific task. + + Args: + base_task: The main task being implemented + tasks: List of tasks to implement + plan: The implementation plan + related_files: List of related files + model: The LLM model to use + expert_enabled: Whether expert mode is enabled + memory: Optional memory instance to use + config: Optional configuration dictionary + thread_id: Optional thread ID (defaults to new UUID) + + Returns: + Optional[str]: The completion message if task completed successfully + """ + # Initialize memory if not provided + if memory is None: + memory = MemorySaver() + + # Set up thread ID + if thread_id is None: + thread_id = str(uuid.uuid4()) + + # Configure tools + tools = get_implementation_tools(expert_enabled=expert_enabled) + + # Create agent + agent = create_react_agent(model, tools, checkpointer=memory) + + # Build prompt + prompt = IMPLEMENTATION_PROMPT.format( + base_task=base_task, + task=task, + tasks=tasks, + plan=plan, + related_files=related_files, + key_facts=get_memory_value('key_facts'), + key_snippets=get_memory_value('key_snippets'), + expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "", + human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION if _global_memory.get('config', {}).get('hil', False) else "" + ) + + # Set up configuration + run_config = { + "configurable": {"thread_id": thread_id}, + "recursion_limit": 100 + } + if config: + run_config.update(config) + + # Run agent with retry logic + return run_agent_with_retry(agent, prompt, run_config) + def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: """Run an agent with retry logic for internal server errors and task completion handling.