Extract run_planning_agent.
This commit is contained in:
parent
c65fe077b3
commit
ecb6796008
|
|
@ -8,8 +8,11 @@ from langgraph.prebuilt import create_react_agent
|
||||||
from ra_aid.env import validate_environment
|
from ra_aid.env import validate_environment
|
||||||
from ra_aid.tools.memory import _global_memory, get_related_files, get_memory_value
|
from ra_aid.tools.memory import _global_memory, get_related_files, get_memory_value
|
||||||
from ra_aid import print_stage_header, print_error
|
from ra_aid import print_stage_header, print_error
|
||||||
from ra_aid.agent_utils import run_agent_with_retry
|
from ra_aid.agent_utils import (
|
||||||
from ra_aid.agent_utils import run_research_agent
|
run_agent_with_retry,
|
||||||
|
run_research_agent,
|
||||||
|
run_planning_agent
|
||||||
|
)
|
||||||
from ra_aid.prompts import (
|
from ra_aid.prompts import (
|
||||||
PLANNING_PROMPT,
|
PLANNING_PROMPT,
|
||||||
CHAT_PROMPT,
|
CHAT_PROMPT,
|
||||||
|
|
@ -207,26 +210,15 @@ def main():
|
||||||
|
|
||||||
# Proceed with planning and implementation if not an informational query
|
# Proceed with planning and implementation if not an informational query
|
||||||
if not is_informational_query():
|
if not is_informational_query():
|
||||||
print_stage_header("Planning Stage")
|
|
||||||
|
|
||||||
# Create planning agent
|
|
||||||
planning_agent = create_react_agent(model, get_planning_tools(expert_enabled=expert_enabled), checkpointer=planning_memory)
|
|
||||||
|
|
||||||
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
|
|
||||||
human_section = HUMAN_PROMPT_SECTION_PLANNING if args.hil else ""
|
|
||||||
planning_prompt = PLANNING_PROMPT.format(
|
|
||||||
expert_section=expert_section,
|
|
||||||
human_section=human_section,
|
|
||||||
base_task=base_task,
|
|
||||||
research_notes=get_memory_value('research_notes'),
|
|
||||||
related_files="\n".join(get_related_files()),
|
|
||||||
key_facts=get_memory_value('key_facts'),
|
|
||||||
key_snippets=get_memory_value('key_snippets'),
|
|
||||||
research_only_note='' if args.research_only else ' Only request implementation if the user explicitly asked for changes to be made.'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run planning agent
|
# Run planning agent
|
||||||
run_agent_with_retry(planning_agent, planning_prompt, config)
|
run_planning_agent(
|
||||||
|
base_task,
|
||||||
|
model,
|
||||||
|
expert_enabled=expert_enabled,
|
||||||
|
hil=args.hil,
|
||||||
|
memory=planning_memory,
|
||||||
|
config=config
|
||||||
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("\n[red]Operation cancelled by user[/red]")
|
console.print("\n[red]Operation cancelled by user[/red]")
|
||||||
|
|
|
||||||
|
|
@ -5,14 +5,22 @@ import uuid
|
||||||
from typing import Optional, Any, List
|
from typing import Optional, Any, List
|
||||||
|
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
from ra_aid.tool_configs import get_implementation_tools, get_research_tools
|
from ra_aid.console.formatting import print_stage_header
|
||||||
|
from ra_aid.tool_configs import (
|
||||||
|
get_implementation_tools,
|
||||||
|
get_research_tools,
|
||||||
|
get_planning_tools
|
||||||
|
)
|
||||||
from ra_aid.prompts import (
|
from ra_aid.prompts import (
|
||||||
IMPLEMENTATION_PROMPT,
|
IMPLEMENTATION_PROMPT,
|
||||||
EXPERT_PROMPT_SECTION_IMPLEMENTATION,
|
EXPERT_PROMPT_SECTION_IMPLEMENTATION,
|
||||||
HUMAN_PROMPT_SECTION_IMPLEMENTATION,
|
HUMAN_PROMPT_SECTION_IMPLEMENTATION,
|
||||||
EXPERT_PROMPT_SECTION_RESEARCH,
|
EXPERT_PROMPT_SECTION_RESEARCH,
|
||||||
RESEARCH_PROMPT,
|
RESEARCH_PROMPT,
|
||||||
HUMAN_PROMPT_SECTION_RESEARCH
|
HUMAN_PROMPT_SECTION_RESEARCH,
|
||||||
|
PLANNING_PROMPT,
|
||||||
|
EXPERT_PROMPT_SECTION_PLANNING,
|
||||||
|
HUMAN_PROMPT_SECTION_PLANNING
|
||||||
)
|
)
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
|
|
@ -26,6 +34,7 @@ from rich.panel import Panel
|
||||||
from ra_aid.tools.memory import (
|
from ra_aid.tools.memory import (
|
||||||
_global_memory,
|
_global_memory,
|
||||||
get_memory_value,
|
get_memory_value,
|
||||||
|
get_related_files,
|
||||||
)
|
)
|
||||||
from ra_aid.globals import RESEARCH_AGENT_RECURSION_LIMIT
|
from ra_aid.globals import RESEARCH_AGENT_RECURSION_LIMIT
|
||||||
from ra_aid.tool_configs import get_research_tools
|
from ra_aid.tool_configs import get_research_tools
|
||||||
|
|
@ -127,6 +136,72 @@ def print_error(msg: str) -> None:
|
||||||
"""Print error messages."""
|
"""Print error messages."""
|
||||||
console.print(f"\n{msg}", style="red")
|
console.print(f"\n{msg}", style="red")
|
||||||
|
|
||||||
|
def run_planning_agent(
|
||||||
|
base_task: str,
|
||||||
|
model,
|
||||||
|
*,
|
||||||
|
expert_enabled: bool = False,
|
||||||
|
hil: bool = False,
|
||||||
|
memory: Optional[Any] = None,
|
||||||
|
config: Optional[dict] = None,
|
||||||
|
thread_id: Optional[str] = None
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Run a planning agent to create implementation plans.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_task: The main task to plan implementation for
|
||||||
|
model: The LLM model to use
|
||||||
|
expert_enabled: Whether expert mode is enabled
|
||||||
|
hil: Whether human-in-the-loop mode is enabled
|
||||||
|
memory: Optional memory instance to use
|
||||||
|
config: Optional configuration dictionary
|
||||||
|
thread_id: Optional thread ID (defaults to new UUID)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: The completion message if planning completed successfully
|
||||||
|
"""
|
||||||
|
# Initialize memory if not provided
|
||||||
|
if memory is None:
|
||||||
|
memory = MemorySaver()
|
||||||
|
|
||||||
|
# Set up thread ID
|
||||||
|
if thread_id is None:
|
||||||
|
thread_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Configure tools
|
||||||
|
tools = get_planning_tools(expert_enabled=expert_enabled)
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
agent = create_react_agent(model, tools, checkpointer=memory)
|
||||||
|
|
||||||
|
# Format prompt sections
|
||||||
|
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
|
||||||
|
human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
|
||||||
|
|
||||||
|
# Build prompt
|
||||||
|
planning_prompt = PLANNING_PROMPT.format(
|
||||||
|
expert_section=expert_section,
|
||||||
|
human_section=human_section,
|
||||||
|
base_task=base_task,
|
||||||
|
research_notes=get_memory_value('research_notes'),
|
||||||
|
related_files="\n".join(get_related_files()),
|
||||||
|
key_facts=get_memory_value('key_facts'),
|
||||||
|
key_snippets=get_memory_value('key_snippets'),
|
||||||
|
research_only_note='' if config.get('research_only') else ' Only request implementation if the user explicitly asked for changes to be made.'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up configuration
|
||||||
|
run_config = {
|
||||||
|
"configurable": {"thread_id": thread_id},
|
||||||
|
"recursion_limit": 100
|
||||||
|
}
|
||||||
|
if config:
|
||||||
|
run_config.update(config)
|
||||||
|
|
||||||
|
# Run agent with retry logic
|
||||||
|
print_stage_header("Planning Stage")
|
||||||
|
return run_agent_with_retry(agent, planning_prompt, run_config)
|
||||||
|
|
||||||
def run_task_implementation_agent(
|
def run_task_implementation_agent(
|
||||||
base_task: str,
|
base_task: str,
|
||||||
tasks: list,
|
tasks: list,
|
||||||
|
|
|
||||||
|
|
@ -105,6 +105,10 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]:
|
||||||
# Get completion message if available
|
# Get completion message if available
|
||||||
completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None)
|
completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None)
|
||||||
|
|
||||||
|
# Clear completion state from global memory
|
||||||
|
_global_memory['completion_message'] = ''
|
||||||
|
_global_memory['completion_state'] = False
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"facts": get_memory_value("key_facts"),
|
"facts": get_memory_value("key_facts"),
|
||||||
"files": list(get_related_files()),
|
"files": list(get_related_files()),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue