From ecb6796008bb258c4d9e5968ef60ee1276faef62 Mon Sep 17 00:00:00 2001 From: user Date: Sat, 21 Dec 2024 14:10:55 -0500 Subject: [PATCH] Extract run_planning_agent. --- ra_aid/__main__.py | 34 +++++++------------ ra_aid/agent_utils.py | 79 +++++++++++++++++++++++++++++++++++++++++-- ra_aid/tools/agent.py | 4 +++ 3 files changed, 94 insertions(+), 23 deletions(-) diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 727eb7f..20fbd5e 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -8,8 +8,11 @@ from langgraph.prebuilt import create_react_agent from ra_aid.env import validate_environment from ra_aid.tools.memory import _global_memory, get_related_files, get_memory_value from ra_aid import print_stage_header, print_error -from ra_aid.agent_utils import run_agent_with_retry -from ra_aid.agent_utils import run_research_agent +from ra_aid.agent_utils import ( + run_agent_with_retry, + run_research_agent, + run_planning_agent +) from ra_aid.prompts import ( PLANNING_PROMPT, CHAT_PROMPT, @@ -207,26 +210,15 @@ def main(): # Proceed with planning and implementation if not an informational query if not is_informational_query(): - print_stage_header("Planning Stage") - - # Create planning agent - planning_agent = create_react_agent(model, get_planning_tools(expert_enabled=expert_enabled), checkpointer=planning_memory) - - expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else "" - human_section = HUMAN_PROMPT_SECTION_PLANNING if args.hil else "" - planning_prompt = PLANNING_PROMPT.format( - expert_section=expert_section, - human_section=human_section, - base_task=base_task, - research_notes=get_memory_value('research_notes'), - related_files="\n".join(get_related_files()), - key_facts=get_memory_value('key_facts'), - key_snippets=get_memory_value('key_snippets'), - research_only_note='' if args.research_only else ' Only request implementation if the user explicitly asked for changes to be made.' - ) - # Run planning agent - run_agent_with_retry(planning_agent, planning_prompt, config) + run_planning_agent( + base_task, + model, + expert_enabled=expert_enabled, + hil=args.hil, + memory=planning_memory, + config=config + ) except KeyboardInterrupt: console.print("\n[red]Operation cancelled by user[/red]") diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 11ed096..5b35c16 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -5,14 +5,22 @@ import uuid from typing import Optional, Any, List from langgraph.prebuilt import create_react_agent -from ra_aid.tool_configs import get_implementation_tools, get_research_tools +from ra_aid.console.formatting import print_stage_header +from ra_aid.tool_configs import ( + get_implementation_tools, + get_research_tools, + get_planning_tools +) from ra_aid.prompts import ( IMPLEMENTATION_PROMPT, EXPERT_PROMPT_SECTION_IMPLEMENTATION, HUMAN_PROMPT_SECTION_IMPLEMENTATION, EXPERT_PROMPT_SECTION_RESEARCH, RESEARCH_PROMPT, - HUMAN_PROMPT_SECTION_RESEARCH + HUMAN_PROMPT_SECTION_RESEARCH, + PLANNING_PROMPT, + EXPERT_PROMPT_SECTION_PLANNING, + HUMAN_PROMPT_SECTION_PLANNING ) from langgraph.checkpoint.memory import MemorySaver @@ -26,6 +34,7 @@ from rich.panel import Panel from ra_aid.tools.memory import ( _global_memory, get_memory_value, + get_related_files, ) from ra_aid.globals import RESEARCH_AGENT_RECURSION_LIMIT from ra_aid.tool_configs import get_research_tools @@ -127,6 +136,72 @@ def print_error(msg: str) -> None: """Print error messages.""" console.print(f"\n{msg}", style="red") +def run_planning_agent( + base_task: str, + model, + *, + expert_enabled: bool = False, + hil: bool = False, + memory: Optional[Any] = None, + config: Optional[dict] = None, + thread_id: Optional[str] = None +) -> Optional[str]: + """Run a planning agent to create implementation plans. + + Args: + base_task: The main task to plan implementation for + model: The LLM model to use + expert_enabled: Whether expert mode is enabled + hil: Whether human-in-the-loop mode is enabled + memory: Optional memory instance to use + config: Optional configuration dictionary + thread_id: Optional thread ID (defaults to new UUID) + + Returns: + Optional[str]: The completion message if planning completed successfully + """ + # Initialize memory if not provided + if memory is None: + memory = MemorySaver() + + # Set up thread ID + if thread_id is None: + thread_id = str(uuid.uuid4()) + + # Configure tools + tools = get_planning_tools(expert_enabled=expert_enabled) + + # Create agent + agent = create_react_agent(model, tools, checkpointer=memory) + + # Format prompt sections + expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else "" + human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else "" + + # Build prompt + planning_prompt = PLANNING_PROMPT.format( + expert_section=expert_section, + human_section=human_section, + base_task=base_task, + research_notes=get_memory_value('research_notes'), + related_files="\n".join(get_related_files()), + key_facts=get_memory_value('key_facts'), + key_snippets=get_memory_value('key_snippets'), + research_only_note='' if config.get('research_only') else ' Only request implementation if the user explicitly asked for changes to be made.' + ) + + # Set up configuration + run_config = { + "configurable": {"thread_id": thread_id}, + "recursion_limit": 100 + } + if config: + run_config.update(config) + + # Run agent with retry logic + print_stage_header("Planning Stage") + return run_agent_with_retry(agent, planning_prompt, run_config) + def run_task_implementation_agent( base_task: str, tasks: list, diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 97087f8..8b32003 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -104,6 +104,10 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]: # Get completion message if available completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) + + # Clear completion state from global memory + _global_memory['completion_message'] = '' + _global_memory['completion_state'] = False return { "facts": get_memory_value("key_facts"),