318 lines
11 KiB
Python
318 lines
11 KiB
Python
import argparse
|
|
import sys
|
|
import uuid
|
|
from datetime import datetime
|
|
from rich.panel import Panel
|
|
from rich.console import Console
|
|
from langgraph.checkpoint.memory import MemorySaver
|
|
from ra_aid.env import validate_environment
|
|
from ra_aid.project_info import get_project_info, format_project_info, display_project_status
|
|
from ra_aid.tools.memory import _global_memory
|
|
from ra_aid.tools.human import ask_human
|
|
from ra_aid import print_stage_header, print_error
|
|
from ra_aid.tools.human import ask_human
|
|
from ra_aid.__version__ import __version__
|
|
from ra_aid.agent_utils import (
|
|
AgentInterrupt,
|
|
run_agent_with_retry,
|
|
run_research_agent,
|
|
run_planning_agent,
|
|
create_agent
|
|
)
|
|
from ra_aid.prompts import (
|
|
CHAT_PROMPT,
|
|
WEB_RESEARCH_PROMPT_SECTION_CHAT
|
|
)
|
|
from ra_aid.llm import initialize_llm
|
|
from ra_aid.logging_config import setup_logging, get_logger
|
|
from ra_aid.tool_configs import (
|
|
get_chat_tools
|
|
)
|
|
from ra_aid.dependencies import check_dependencies
|
|
import os
|
|
|
|
# Module-level logger obtained from the project's logging helper, keyed by this module's name.
logger = get_logger(__name__)
|
|
|
|
def parse_arguments(args=None):
    """Parse and validate the ra-aid command line.

    Args:
        args: Optional list of argument strings. When None, ``sys.argv[1:]``
            is parsed instead.

    Returns:
        The ``argparse.Namespace`` holding the parsed options. ``hil`` is
        forced to True when ``--chat`` is given, and ``model`` is filled in
        for the ``openai`` and ``anthropic`` providers.

    Raises:
        SystemExit: Raised by argparse (via ``parser.error``) when an option
            is invalid, a required model name is missing, or the temperature
            is outside 0.0-2.0.
    """
    providers = ['anthropic', 'openai', 'openrouter', 'openai-compatible', 'gemini']
    anthropic_model = 'claude-3-5-sonnet-20241022'
    openai_model = 'gpt-4o'

    # OpenAI becomes the default provider only when its API key is set and
    # the Anthropic key is not; in every other case Anthropic is the default.
    if os.getenv('OPENAI_API_KEY') and not os.getenv('ANTHROPIC_API_KEY'):
        default_provider = 'openai'
    else:
        default_provider = 'anthropic'

    cli = argparse.ArgumentParser(
        description='RA.Aid - AI Agent for executing programming and research tasks',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Examples:
    ra-aid -m "Add error handling to the database module"
    ra-aid -m "Explain the authentication flow" --research-only
        '''
    )

    def add_flag(*names, help):
        """Register a boolean switch that defaults to False."""
        cli.add_argument(*names, action='store_true', help=help)

    # Options are registered in the same order as before so that --help
    # output is unchanged.
    cli.add_argument(
        '-m', '--message',
        type=str,
        help='The task or query to be executed by the agent'
    )
    cli.add_argument(
        '--version',
        action='version',
        version=f'%(prog)s {__version__}',
        help='Show program version number and exit'
    )
    add_flag('--research-only', help='Only perform research without implementation')
    cli.add_argument(
        '--provider',
        type=str,
        default=default_provider,
        choices=providers,
        help='The LLM provider to use'
    )
    cli.add_argument('--model', type=str, help='The model name to use')
    add_flag('--cowboy-mode', help='Skip interactive approval for shell commands')
    cli.add_argument(
        '--expert-provider',
        type=str,
        default='openai',
        choices=providers,
        help='The LLM provider to use for expert knowledge queries (default: openai)'
    )
    cli.add_argument(
        '--expert-model',
        type=str,
        help='The model name to use for expert knowledge queries (required for non-OpenAI providers)'
    )
    add_flag(
        '--hil', '-H',
        help='Enable human-in-the-loop mode, where the agent can prompt the user for additional information.'
    )
    add_flag('--chat', help='Enable chat mode with direct human interaction (implies --hil)')
    add_flag('--verbose', help='Enable verbose logging output')
    cli.add_argument(
        '--temperature',
        type=float,
        default=None,
        help='LLM temperature (0.0-2.0). Controls randomness in responses'
    )

    opts = cli.parse_args(sys.argv[1:] if args is None else args)

    # Chat mode always implies human-in-the-loop.
    if opts.chat:
        opts.hil = True

    if opts.provider not in providers:
        cli.error(f"Invalid provider: {opts.provider}")

    # Model defaults and requirements per provider.
    if opts.provider == "openai":
        opts.model = opts.model or openai_model
    elif opts.provider == 'anthropic':
        # The Anthropic provider always runs the default model; any --model
        # value supplied by the user is replaced here.
        opts.model = anthropic_model
    elif not (opts.model or opts.research_only):
        # Every other provider needs an explicit model unless only research
        # was requested.
        cli.error(f"--model is required when using provider '{opts.provider}'")

    # A non-OpenAI expert provider needs an explicit expert model, again
    # unless only research was requested.
    expert_model_needed = opts.expert_provider != 'openai' and not opts.research_only
    if expert_model_needed and not opts.expert_model:
        cli.error(f"--expert-model is required when using expert provider '{opts.expert_provider}'")

    # Range check is skipped when no temperature was supplied.
    if opts.temperature is not None and not (0.0 <= opts.temperature <= 2.0):
        cli.error('Temperature must be between 0.0 and 2.0')

    return opts
|
|
|
|
# Shared rich console used by this module for user-facing output.
console = Console()

# One in-memory LangGraph checkpointer per agent stage, so each stage keeps
# its own conversation state.
# NOTE(review): implementation_memory is not referenced in the visible part
# of this file - confirm whether it is still needed.
research_memory = MemorySaver()
planning_memory = MemorySaver()
implementation_memory = MemorySaver()
|
|
|
|
|
|
def is_informational_query() -> bool:
    """Tell whether the current query should stop after the research stage.

    The query counts as informational when the stored config has
    ``research_only`` set, or when the implementation stage has not been
    requested.
    """
    stored_config = _global_memory.get('config', {})
    research_only = stored_config.get('research_only', False)
    return research_only or not is_stage_requested('implementation')
|
|
|
|
def is_stage_requested(stage: str) -> bool:
    """Report whether the given stage has been requested to proceed.

    Only the ``'implementation'`` stage is tracked; it is looked up under
    the ``implementation_requested`` key of the global memory. Any other
    stage name yields False.
    """
    if stage != 'implementation':
        return False
    return _global_memory.get('implementation_requested', False)
|
|
|
|
def _print_disabled_panel(title, headline, missing, hint):
    """Print a yellow warning panel listing each missing configuration item."""
    body = (
        f"[yellow]{headline}[/yellow]\n"
        + "\n".join(f"- {item}" for item in missing)
        + f"\n{hint}"
    )
    console.print(Panel(body, title=title, style="yellow"))


def _publish_config(config, args):
    """Store config in global memory and record the provider/model selections on it.

    The same dict object is stored and mutated, so callers that keep passing
    ``config`` to the agents see the added keys as well.
    """
    _global_memory['config'] = config
    config['provider'] = args.provider
    config['model'] = args.model
    config['expert_provider'] = args.expert_provider
    config['expert_model'] = args.expert_model


def _run_chat_mode(args, model, expert_enabled, web_research_enabled):
    """Run the interactive chat agent; exits with status 1 if --research-only is also set."""
    if args.research_only:
        print_error("Chat mode cannot be used with --research-only")
        sys.exit(1)

    print_stage_header("Chat Mode")

    # Project info is best-effort: any failure is logged and the prompt is
    # built without it.
    try:
        project_info = get_project_info(".", file_limit=2000)
        formatted_project_info = format_project_info(project_info)
    except Exception as e:
        logger.warning("Failed to get project info: %s", e)
        formatted_project_info = ""

    initial_request = ask_human.invoke({"question": "What would you like help with?"})

    working_directory = os.getcwd()
    current_date = datetime.now().strftime("%Y-%m-%d")

    config = {
        # LangGraph documents thread_id as a string, so the UUID is
        # stringified rather than passed as a uuid.UUID object.
        "configurable": {"thread_id": str(uuid.uuid4())},
        "recursion_limit": 100,
        "chat_mode": True,
        "cowboy_mode": args.cowboy_mode,
        "hil": True,  # Always true in chat mode
        "web_research_enabled": web_research_enabled,
        "initial_request": initial_request,
    }
    _publish_config(config, args)

    chat_agent = create_agent(
        model,
        get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
        checkpointer=MemorySaver(),
    )

    prompt = CHAT_PROMPT.format(
        initial_request=initial_request,
        web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else "",
        working_directory=working_directory,
        current_date=current_date,
        project_info=formatted_project_info,
    )
    run_agent_with_retry(chat_agent, prompt, config)


def main():
    """Main entry point for the ra-aid command line tool.

    Parses the command line, configures logging, checks dependencies and the
    environment, then either runs chat mode or runs the research stage
    followed (unless the query is informational) by the planning stage.

    Exits with status 1 when ``--chat`` is combined with ``--research-only``
    or when ``--message`` is missing outside chat mode, and with status 0
    after a KeyboardInterrupt or AgentInterrupt.
    """
    args = parse_arguments()
    setup_logging(args.verbose)
    logger.debug("Starting RA.Aid with arguments: %s", args)

    try:
        # Check dependencies before proceeding
        check_dependencies()

        # Will exit if main env vars missing
        (
            expert_enabled,
            expert_missing,
            web_research_enabled,
            web_research_missing,
        ) = validate_environment(args)
        logger.debug("Environment validation successful")

        if expert_missing:
            _print_disabled_panel(
                "Expert Tools Disabled",
                "Expert tools disabled due to missing configuration:",
                expert_missing,
                "Set the required environment variables or args to enable expert mode.",
            )

        if web_research_missing:
            _print_disabled_panel(
                "Web Research Disabled",
                "Web research disabled due to missing configuration:",
                web_research_missing,
                "Set the required environment variables to enable web research.",
            )

        # Create the base model after validation
        model = initialize_llm(args.provider, args.model, temperature=args.temperature)

        if args.chat:
            _run_chat_mode(args, model, expert_enabled, web_research_enabled)
            return

        # Outside chat mode a task message is mandatory.
        if not args.message:
            print_error("--message is required")
            sys.exit(1)

        base_task = args.message
        config = {
            # Stringified for the same reason as in chat mode: LangGraph
            # documents thread_id as a string.
            "configurable": {"thread_id": str(uuid.uuid4())},
            "recursion_limit": 100,
            "research_only": args.research_only,
            "cowboy_mode": args.cowboy_mode,
            "web_research_enabled": web_research_enabled,
        }

        # is_informational_query reads research_only from this stored config.
        _publish_config(config, args)

        # Run research stage
        print_stage_header("Research Stage")
        run_research_agent(
            base_task,
            model,
            expert_enabled=expert_enabled,
            research_only=args.research_only,
            hil=args.hil,
            memory=research_memory,
            config=config,
        )

        # Proceed with planning and implementation if not an informational query
        if not is_informational_query():
            run_planning_agent(
                base_task,
                model,
                expert_enabled=expert_enabled,
                hil=args.hil,
                memory=planning_memory,
                config=config,
            )

    except (KeyboardInterrupt, AgentInterrupt):
        # An interrupt is treated as a normal, user-requested exit.
        print()
        print(" 👋 Bye!")
        print()
        sys.exit(0)
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|