diff --git a/README.md b/README.md index ff4279b..6aa9851 100644 --- a/README.md +++ b/README.md @@ -162,16 +162,20 @@ ra-aid -m "Add new feature" --verbose ### Command Line Options -- `-m, --message`: The task or query to be executed (required) +- `-m, --message`: The task or query to be executed (required except in chat mode) - `--research-only`: Only perform research without implementation +- `--provider`: The LLM provider to use (choices: anthropic, openai, openrouter, openai-compatible, gemini) +- `--model`: The model name to use (required for non-Anthropic providers) - `--cowboy-mode`: Skip interactive approval for shell commands -- `--hil, -H`: Enable human-in-the-loop mode, allowing the agent to interactively ask you questions during task execution -- `--provider`: Specify the model provider (See Model Configuration section) -- `--model`: Specify the model name (See Model Configuration section) -- `--expert-provider`: Specify the provider for the expert tool (defaults to OpenAI) -- `--expert-model`: Specify the model name for the expert tool (defaults to o1 for OpenAI) -- `--chat`: Enable chat mode for interactive assistance -- `--verbose`: Enable detailed logging output for debugging and monitoring +- `--expert-provider`: The LLM provider to use for expert knowledge queries (choices: anthropic, openai, openrouter, openai-compatible, gemini) +- `--expert-model`: The model name to use for expert knowledge queries (required for non-OpenAI providers) +- `--hil, -H`: Enable human-in-the-loop mode for interactive assistance during task execution +- `--chat`: Enable chat mode with direct human interaction (implies --hil) +- `--verbose`: Enable verbose logging output +- `--temperature`: LLM temperature (0.0-2.0) to control randomness in responses +- `--disable-limit-tokens`: Disable token limiting for Anthropic Claude react agents +- `--recursion-limit`: Maximum recursion depth for agent operations (default: 100) +- `--version`: Show program version number and exit ### Example Tasks diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 9860b2a..72df89b 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -5,6 +5,7 @@ from datetime import datetime from rich.panel import Panel from rich.console import Console from langgraph.checkpoint.memory import MemorySaver +from ra_aid.config import DEFAULT_RECURSION_LIMIT from ra_aid.env import validate_environment from ra_aid.project_info import ( get_project_info, @@ -121,6 +122,12 @@ Examples: action="store_false", help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.", ) + parser.add_argument( + "--recursion-limit", + type=int, + default=DEFAULT_RECURSION_LIMIT, + help="Maximum recursion depth for agent operations (default: 100)", + ) if args is None: args = sys.argv[1:] @@ -162,6 +169,10 @@ Examples: ): parser.error("Temperature must be between 0.0 and 2.0") + # Validate recursion limit is positive + if parsed_args.recursion_limit <= 0: + parser.error("Recursion limit must be positive") + return parsed_args @@ -255,8 +266,8 @@ def main(): # Run chat agent with CHAT_PROMPT config = { - "configurable": {"thread_id": uuid.uuid4()}, - "recursion_limit": 100, + "configurable": {"thread_id": str(uuid.uuid4())}, + "recursion_limit": args.recursion_limit, "chat_mode": True, "cowboy_mode": args.cowboy_mode, "hil": True, # Always true in chat mode @@ -305,8 +316,8 @@ def main(): base_task = args.message config = { - "configurable": {"thread_id": uuid.uuid4()}, - "recursion_limit": 100, + "configurable": {"thread_id": str(uuid.uuid4())}, + "recursion_limit": args.recursion_limit, "research_only": args.research_only, "cowboy_mode": args.cowboy_mode, "web_research_enabled": web_research_enabled, diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index e1c5479..7d52450 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -10,6 +10,7 @@ import signal from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt.chat_agent_executor import AgentState +from ra_aid.config import DEFAULT_RECURSION_LIMIT from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens from ra_aid.agents.ciayn_agent import CiaynAgent import threading @@ -293,15 +294,12 @@ def run_research_agent( web_research_enabled, ) - # Initialize memory if not provided if memory is None: memory = MemorySaver() - # Set up thread ID if thread_id is None: thread_id = str(uuid.uuid4()) - # Configure tools tools = get_research_tools( research_only=research_only, expert_enabled=expert_enabled, @@ -309,10 +307,8 @@ def run_research_agent( web_research_enabled=config.get("web_research_enabled", False), ) - # Create agent agent = create_agent(model, tools, checkpointer=memory) - # Format prompt sections expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" web_research_section = ( @@ -321,12 +317,10 @@ def run_research_agent( else "" ) - # Get research context from memory key_facts = _global_memory.get("key_facts", "") code_snippets = _global_memory.get("code_snippets", "") related_files = _global_memory.get("related_files", "") - # Get project info try: project_info = get_project_info(".", file_limit=2000) formatted_project_info = format_project_info(project_info) @@ -334,7 +328,6 @@ def run_research_agent( logger.warning(f"Failed to get project info: {e}") formatted_project_info = "" - # Build prompt prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format( base_task=base_task_or_query, research_only_note="" @@ -350,13 +343,13 @@ def run_research_agent( project_info=formatted_project_info, ) - # Set up configuration - run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100} + config = _global_memory.get("config", {}) if not config else config + recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) + run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit} if config: run_config.update(config) try: - # Display console message if provided if console_message: console.print( Panel(Markdown(console_message), title="🔬 Looking into it...") @@ -365,12 +358,10 @@ def run_research_agent( if project_info: display_project_status(project_info) - # Run agent with retry logic if available if agent is not None: logger.debug("Research agent completed successfully") return run_agent_with_retry(agent, prompt, run_config) else: - # Just run web research tools directly if no agent logger.debug("No model provided, running web research tools directly") return run_web_research_agent( base_task_or_query, @@ -434,30 +425,23 @@ def run_web_research_agent( web_research_enabled, ) - # Initialize memory if not provided if memory is None: memory = MemorySaver() - # Set up thread ID if thread_id is None: thread_id = str(uuid.uuid4()) - # Configure tools using restricted web research toolset tools = get_web_research_tools(expert_enabled=expert_enabled) - # Create agent agent = create_agent(model, tools, checkpointer=memory) - # Format prompt sections expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" - # Get research context from memory key_facts = _global_memory.get("key_facts", "") code_snippets = _global_memory.get("code_snippets", "") related_files = _global_memory.get("related_files", "") - # Build prompt prompt = WEB_RESEARCH_PROMPT.format( web_research_query=query, expert_section=expert_section, @@ -467,13 +451,13 @@ def run_web_research_agent( related_files=related_files, ) - # Set up configuration - run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100} + config = _global_memory.get("config", {}) if not config else config + recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) + run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit} if config: run_config.update(config) try: - # Display console message if provided if console_message: console.print(Panel(Markdown(console_message), title="🔬 Researching...")) @@ -515,24 +499,19 @@ def run_planning_agent( logger.debug("Starting planning agent with thread_id=%s", thread_id) logger.debug("Planning configuration: expert=%s, hil=%s", expert_enabled, hil) - # Initialize memory if not provided if memory is None: memory = MemorySaver() - # Set up thread ID if thread_id is None: thread_id = str(uuid.uuid4()) - # Configure tools tools = get_planning_tools( expert_enabled=expert_enabled, web_research_enabled=config.get("web_research_enabled", False), ) - # Create agent agent = create_agent(model, tools, checkpointer=memory) - # Format prompt sections expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else "" human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else "" web_research_section = ( @@ -541,7 +520,6 @@ def run_planning_agent( else "" ) - # Build prompt planning_prompt = PLANNING_PROMPT.format( expert_section=expert_section, human_section=human_section, @@ -557,8 +535,9 @@ def run_planning_agent( else " Only request implementation if the user explicitly asked for changes to be made.", ) - # Set up configuration - run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100} + config = _global_memory.get("config", {}) if not config else config + recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) + run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit} if config: run_config.update(config) @@ -614,24 +593,19 @@ def run_task_implementation_agent( logger.debug("Task details: base_task=%s, current_task=%s", base_task, task) logger.debug("Related files: %s", related_files) - # Initialize memory if not provided if memory is None: memory = MemorySaver() - # Set up thread ID if thread_id is None: thread_id = str(uuid.uuid4()) - # Configure tools tools = get_implementation_tools( expert_enabled=expert_enabled, web_research_enabled=config.get("web_research_enabled", False), ) - # Create agent agent = create_agent(model, tools, checkpointer=memory) - # Build prompt prompt = IMPLEMENTATION_PROMPT.format( base_task=base_task, task=task, @@ -651,8 +625,9 @@ def run_task_implementation_agent( else "", ) - # Set up configuration - run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100} + config = _global_memory.get("config", {}) if not config else config + recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) + run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit} if config: run_config.update(config) diff --git a/ra_aid/config.py b/ra_aid/config.py index 23a3531..2a2a187 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -1 +1,3 @@ """Configuration utilities.""" + +DEFAULT_RECURSION_LIMIT = 100 diff --git a/tests/ra_aid/test_main.py b/tests/ra_aid/test_main.py new file mode 100644 index 0000000..87293fe --- /dev/null +++ b/tests/ra_aid/test_main.py @@ -0,0 +1,123 @@ +"""Unit tests for __main__.py argument parsing.""" + +import pytest +from ra_aid.__main__ import parse_arguments +from ra_aid.tools.memory import _global_memory +from ra_aid.config import DEFAULT_RECURSION_LIMIT + + +@pytest.fixture +def mock_dependencies(monkeypatch): + """Mock all dependencies needed for main().""" + monkeypatch.setattr('ra_aid.__main__.check_dependencies', lambda: None) + + monkeypatch.setattr('ra_aid.__main__.validate_environment', + lambda args: (True, [], True, [])) + + def mock_config_update(*args, **kwargs): + config = _global_memory.get("config", {}) + if kwargs.get("temperature"): + config["temperature"] = kwargs["temperature"] + _global_memory["config"] = config + return None + + monkeypatch.setattr('ra_aid.__main__.initialize_llm', + mock_config_update) + + monkeypatch.setattr('ra_aid.__main__.run_research_agent', + lambda *args, **kwargs: None) + +def test_recursion_limit_in_global_config(mock_dependencies): + """Test that recursion limit is correctly set in global config.""" + from ra_aid.__main__ import main + import sys + from unittest.mock import patch + + _global_memory.clear() + + with patch.object(sys, 'argv', ['ra-aid', '-m', 'test message']): + main() + assert _global_memory["config"]["recursion_limit"] == DEFAULT_RECURSION_LIMIT + + _global_memory.clear() + + with patch.object(sys, 'argv', ['ra-aid', '-m', 'test message', '--recursion-limit', '50']): + main() + assert _global_memory["config"]["recursion_limit"] == 50 + + +def test_negative_recursion_limit(): + """Test that negative recursion limit raises error.""" + with pytest.raises(SystemExit): + parse_arguments(["-m", "test message", "--recursion-limit", "-1"]) + + +def test_zero_recursion_limit(): + """Test that zero recursion limit raises error.""" + with pytest.raises(SystemExit): + parse_arguments(["-m", "test message", "--recursion-limit", "0"]) + + +def test_config_settings(mock_dependencies): + """Test that various settings are correctly applied in global config.""" + from ra_aid.__main__ import main + import sys + from unittest.mock import patch + + _global_memory.clear() + + with patch.object(sys, 'argv', [ + 'ra-aid', '-m', 'test message', + '--cowboy-mode', + '--research-only', + '--provider', 'anthropic', + '--model', 'claude-3-5-sonnet-20241022', + '--expert-provider', 'openai', + '--expert-model', 'gpt-4', + '--temperature', '0.7', + '--disable-limit-tokens' + ]): + main() + config = _global_memory["config"] + assert config["cowboy_mode"] is True + assert config["research_only"] is True + assert config["provider"] == "anthropic" + assert config["model"] == "claude-3-5-sonnet-20241022" + assert config["expert_provider"] == "openai" + assert config["expert_model"] == "gpt-4" + assert config["limit_tokens"] is False + + +def test_temperature_validation(mock_dependencies): + """Test that temperature argument is correctly passed to initialize_llm.""" + from ra_aid.__main__ import main, initialize_llm + import sys + from unittest.mock import patch + + _global_memory.clear() + + with patch('ra_aid.__main__.initialize_llm') as mock_init_llm: + with patch.object(sys, 'argv', ['ra-aid', '-m', 'test', '--temperature', '0.7']): + main() + mock_init_llm.assert_called_once() + assert mock_init_llm.call_args.kwargs['temperature'] == 0.7 + + with pytest.raises(SystemExit): + with patch.object(sys, 'argv', ['ra-aid', '-m', 'test', '--temperature', '2.1']): + main() + + +def test_missing_message(): + """Test that missing message argument raises error.""" + # Test chat mode which doesn't require message + args = parse_arguments(["--chat"]) + assert args.chat is True + assert args.message is None + + # Test non-chat mode requires message + args = parse_arguments(["--provider", "openai"]) + assert args.message is None + + # Verify message is captured when provided + args = parse_arguments(["-m", "test"]) + assert args.message == "test"