Add configurable --recursion-limit argument (#46)
* test: Add unit tests for argument parsing in __main__.py * test: Update tests to remove invalid argument and improve error handling * test: Fix test_missing_message to handle missing argument cases correctly * test: Fix test_missing_message to reflect argument parsing behavior * test: Combine recursion limit tests and verify global config updates * fix: Include recursion_limit in config for recursion limit tests * test: Mock dependencies and validate recursion limit in global config * test: Remove commented-out code and clean up test_main.py * test: Remove self-evident comments and improve test assertions in test_main.py * fix: Mock user input and handle temperature in global config tests * fix: Fix test failures by correcting mock targets and handling temperature * test: Update temperature validation to check argument passing to initialize_llm * fix: Correct mock for ask_human and access kwargs in temperature test * fix: Patch the entire ask_human function in test_chat_mode_implies_hil * docs: Add recursion limit option to README documentation * docs: Update README.md with all available command line arguments * feat(config): add DEFAULT_RECURSION_LIMIT constant to set default recursion depth feat(main.py): add --recursion-limit argument to configure maximum recursion depth for agent operations fix(main.py): validate that recursion limit is positive before processing refactor(main.py): use args.recursion_limit in agent configuration instead of hardcoded value refactor(agent_utils.py): update agent configuration to use recursion limit from global memory or default value refactor(run_research_agent): clean up comments and improve readability refactor(run_web_research_agent): clean up comments and improve readability refactor(run_planning_agent): clean up comments and improve readability refactor(run_task_implementation_agent): clean up comments and improve readability delete(test_main.py): remove obsolete test for chat mode and HIL configuration
This commit is contained in:
parent
e886d98c0e
commit
46e7340ddb
20
README.md
20
README.md
|
|
@ -162,16 +162,20 @@ ra-aid -m "Add new feature" --verbose
|
||||||
|
|
||||||
### Command Line Options
|
### Command Line Options
|
||||||
|
|
||||||
- `-m, --message`: The task or query to be executed (required)
|
- `-m, --message`: The task or query to be executed (required except in chat mode)
|
||||||
- `--research-only`: Only perform research without implementation
|
- `--research-only`: Only perform research without implementation
|
||||||
|
- `--provider`: The LLM provider to use (choices: anthropic, openai, openrouter, openai-compatible, gemini)
|
||||||
|
- `--model`: The model name to use (required for non-Anthropic providers)
|
||||||
- `--cowboy-mode`: Skip interactive approval for shell commands
|
- `--cowboy-mode`: Skip interactive approval for shell commands
|
||||||
- `--hil, -H`: Enable human-in-the-loop mode, allowing the agent to interactively ask you questions during task execution
|
- `--expert-provider`: The LLM provider to use for expert knowledge queries (choices: anthropic, openai, openrouter, openai-compatible, gemini)
|
||||||
- `--provider`: Specify the model provider (See Model Configuration section)
|
- `--expert-model`: The model name to use for expert knowledge queries (required for non-OpenAI providers)
|
||||||
- `--model`: Specify the model name (See Model Configuration section)
|
- `--hil, -H`: Enable human-in-the-loop mode for interactive assistance during task execution
|
||||||
- `--expert-provider`: Specify the provider for the expert tool (defaults to OpenAI)
|
- `--chat`: Enable chat mode with direct human interaction (implies --hil)
|
||||||
- `--expert-model`: Specify the model name for the expert tool (defaults to o1 for OpenAI)
|
- `--verbose`: Enable verbose logging output
|
||||||
- `--chat`: Enable chat mode for interactive assistance
|
- `--temperature`: LLM temperature (0.0-2.0) to control randomness in responses
|
||||||
- `--verbose`: Enable detailed logging output for debugging and monitoring
|
- `--disable-limit-tokens`: Disable token limiting for Anthropic Claude react agents
|
||||||
|
- `--recursion-limit`: Maximum recursion depth for agent operations (default: 100)
|
||||||
|
- `--version`: Show program version number and exit
|
||||||
|
|
||||||
### Example Tasks
|
### Example Tasks
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from datetime import datetime
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
from ra_aid.config import DEFAULT_RECURSION_LIMIT
|
||||||
from ra_aid.env import validate_environment
|
from ra_aid.env import validate_environment
|
||||||
from ra_aid.project_info import (
|
from ra_aid.project_info import (
|
||||||
get_project_info,
|
get_project_info,
|
||||||
|
|
@ -121,6 +122,12 @@ Examples:
|
||||||
action="store_false",
|
action="store_false",
|
||||||
help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.",
|
help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--recursion-limit",
|
||||||
|
type=int,
|
||||||
|
default=DEFAULT_RECURSION_LIMIT,
|
||||||
|
help="Maximum recursion depth for agent operations (default: 100)",
|
||||||
|
)
|
||||||
|
|
||||||
if args is None:
|
if args is None:
|
||||||
args = sys.argv[1:]
|
args = sys.argv[1:]
|
||||||
|
|
@ -162,6 +169,10 @@ Examples:
|
||||||
):
|
):
|
||||||
parser.error("Temperature must be between 0.0 and 2.0")
|
parser.error("Temperature must be between 0.0 and 2.0")
|
||||||
|
|
||||||
|
# Validate recursion limit is positive
|
||||||
|
if parsed_args.recursion_limit <= 0:
|
||||||
|
parser.error("Recursion limit must be positive")
|
||||||
|
|
||||||
return parsed_args
|
return parsed_args
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -255,8 +266,8 @@ def main():
|
||||||
|
|
||||||
# Run chat agent with CHAT_PROMPT
|
# Run chat agent with CHAT_PROMPT
|
||||||
config = {
|
config = {
|
||||||
"configurable": {"thread_id": uuid.uuid4()},
|
"configurable": {"thread_id": str(uuid.uuid4())},
|
||||||
"recursion_limit": 100,
|
"recursion_limit": args.recursion_limit,
|
||||||
"chat_mode": True,
|
"chat_mode": True,
|
||||||
"cowboy_mode": args.cowboy_mode,
|
"cowboy_mode": args.cowboy_mode,
|
||||||
"hil": True, # Always true in chat mode
|
"hil": True, # Always true in chat mode
|
||||||
|
|
@ -305,8 +316,8 @@ def main():
|
||||||
|
|
||||||
base_task = args.message
|
base_task = args.message
|
||||||
config = {
|
config = {
|
||||||
"configurable": {"thread_id": uuid.uuid4()},
|
"configurable": {"thread_id": str(uuid.uuid4())},
|
||||||
"recursion_limit": 100,
|
"recursion_limit": args.recursion_limit,
|
||||||
"research_only": args.research_only,
|
"research_only": args.research_only,
|
||||||
"cowboy_mode": args.cowboy_mode,
|
"cowboy_mode": args.cowboy_mode,
|
||||||
"web_research_enabled": web_research_enabled,
|
"web_research_enabled": web_research_enabled,
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import signal
|
||||||
|
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||||
|
from ra_aid.config import DEFAULT_RECURSION_LIMIT
|
||||||
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens
|
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens
|
||||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
import threading
|
import threading
|
||||||
|
|
@ -293,15 +294,12 @@ def run_research_agent(
|
||||||
web_research_enabled,
|
web_research_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize memory if not provided
|
|
||||||
if memory is None:
|
if memory is None:
|
||||||
memory = MemorySaver()
|
memory = MemorySaver()
|
||||||
|
|
||||||
# Set up thread ID
|
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
thread_id = str(uuid.uuid4())
|
thread_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# Configure tools
|
|
||||||
tools = get_research_tools(
|
tools = get_research_tools(
|
||||||
research_only=research_only,
|
research_only=research_only,
|
||||||
expert_enabled=expert_enabled,
|
expert_enabled=expert_enabled,
|
||||||
|
|
@ -309,10 +307,8 @@ def run_research_agent(
|
||||||
web_research_enabled=config.get("web_research_enabled", False),
|
web_research_enabled=config.get("web_research_enabled", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create agent
|
|
||||||
agent = create_agent(model, tools, checkpointer=memory)
|
agent = create_agent(model, tools, checkpointer=memory)
|
||||||
|
|
||||||
# Format prompt sections
|
|
||||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||||
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
|
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
|
||||||
web_research_section = (
|
web_research_section = (
|
||||||
|
|
@ -321,12 +317,10 @@ def run_research_agent(
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get research context from memory
|
|
||||||
key_facts = _global_memory.get("key_facts", "")
|
key_facts = _global_memory.get("key_facts", "")
|
||||||
code_snippets = _global_memory.get("code_snippets", "")
|
code_snippets = _global_memory.get("code_snippets", "")
|
||||||
related_files = _global_memory.get("related_files", "")
|
related_files = _global_memory.get("related_files", "")
|
||||||
|
|
||||||
# Get project info
|
|
||||||
try:
|
try:
|
||||||
project_info = get_project_info(".", file_limit=2000)
|
project_info = get_project_info(".", file_limit=2000)
|
||||||
formatted_project_info = format_project_info(project_info)
|
formatted_project_info = format_project_info(project_info)
|
||||||
|
|
@ -334,7 +328,6 @@ def run_research_agent(
|
||||||
logger.warning(f"Failed to get project info: {e}")
|
logger.warning(f"Failed to get project info: {e}")
|
||||||
formatted_project_info = ""
|
formatted_project_info = ""
|
||||||
|
|
||||||
# Build prompt
|
|
||||||
prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
|
prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
|
||||||
base_task=base_task_or_query,
|
base_task=base_task_or_query,
|
||||||
research_only_note=""
|
research_only_note=""
|
||||||
|
|
@ -350,13 +343,13 @@ def run_research_agent(
|
||||||
project_info=formatted_project_info,
|
project_info=formatted_project_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set up configuration
|
config = _global_memory.get("config", {}) if not config else config
|
||||||
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
|
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||||
|
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
|
||||||
if config:
|
if config:
|
||||||
run_config.update(config)
|
run_config.update(config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Display console message if provided
|
|
||||||
if console_message:
|
if console_message:
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(console_message), title="🔬 Looking into it...")
|
Panel(Markdown(console_message), title="🔬 Looking into it...")
|
||||||
|
|
@ -365,12 +358,10 @@ def run_research_agent(
|
||||||
if project_info:
|
if project_info:
|
||||||
display_project_status(project_info)
|
display_project_status(project_info)
|
||||||
|
|
||||||
# Run agent with retry logic if available
|
|
||||||
if agent is not None:
|
if agent is not None:
|
||||||
logger.debug("Research agent completed successfully")
|
logger.debug("Research agent completed successfully")
|
||||||
return run_agent_with_retry(agent, prompt, run_config)
|
return run_agent_with_retry(agent, prompt, run_config)
|
||||||
else:
|
else:
|
||||||
# Just run web research tools directly if no agent
|
|
||||||
logger.debug("No model provided, running web research tools directly")
|
logger.debug("No model provided, running web research tools directly")
|
||||||
return run_web_research_agent(
|
return run_web_research_agent(
|
||||||
base_task_or_query,
|
base_task_or_query,
|
||||||
|
|
@ -434,30 +425,23 @@ def run_web_research_agent(
|
||||||
web_research_enabled,
|
web_research_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize memory if not provided
|
|
||||||
if memory is None:
|
if memory is None:
|
||||||
memory = MemorySaver()
|
memory = MemorySaver()
|
||||||
|
|
||||||
# Set up thread ID
|
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
thread_id = str(uuid.uuid4())
|
thread_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# Configure tools using restricted web research toolset
|
|
||||||
tools = get_web_research_tools(expert_enabled=expert_enabled)
|
tools = get_web_research_tools(expert_enabled=expert_enabled)
|
||||||
|
|
||||||
# Create agent
|
|
||||||
agent = create_agent(model, tools, checkpointer=memory)
|
agent = create_agent(model, tools, checkpointer=memory)
|
||||||
|
|
||||||
# Format prompt sections
|
|
||||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||||
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
|
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
|
||||||
|
|
||||||
# Get research context from memory
|
|
||||||
key_facts = _global_memory.get("key_facts", "")
|
key_facts = _global_memory.get("key_facts", "")
|
||||||
code_snippets = _global_memory.get("code_snippets", "")
|
code_snippets = _global_memory.get("code_snippets", "")
|
||||||
related_files = _global_memory.get("related_files", "")
|
related_files = _global_memory.get("related_files", "")
|
||||||
|
|
||||||
# Build prompt
|
|
||||||
prompt = WEB_RESEARCH_PROMPT.format(
|
prompt = WEB_RESEARCH_PROMPT.format(
|
||||||
web_research_query=query,
|
web_research_query=query,
|
||||||
expert_section=expert_section,
|
expert_section=expert_section,
|
||||||
|
|
@ -467,13 +451,13 @@ def run_web_research_agent(
|
||||||
related_files=related_files,
|
related_files=related_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set up configuration
|
config = _global_memory.get("config", {}) if not config else config
|
||||||
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
|
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||||
|
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
|
||||||
if config:
|
if config:
|
||||||
run_config.update(config)
|
run_config.update(config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Display console message if provided
|
|
||||||
if console_message:
|
if console_message:
|
||||||
console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
|
console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
|
||||||
|
|
||||||
|
|
@ -515,24 +499,19 @@ def run_planning_agent(
|
||||||
logger.debug("Starting planning agent with thread_id=%s", thread_id)
|
logger.debug("Starting planning agent with thread_id=%s", thread_id)
|
||||||
logger.debug("Planning configuration: expert=%s, hil=%s", expert_enabled, hil)
|
logger.debug("Planning configuration: expert=%s, hil=%s", expert_enabled, hil)
|
||||||
|
|
||||||
# Initialize memory if not provided
|
|
||||||
if memory is None:
|
if memory is None:
|
||||||
memory = MemorySaver()
|
memory = MemorySaver()
|
||||||
|
|
||||||
# Set up thread ID
|
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
thread_id = str(uuid.uuid4())
|
thread_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# Configure tools
|
|
||||||
tools = get_planning_tools(
|
tools = get_planning_tools(
|
||||||
expert_enabled=expert_enabled,
|
expert_enabled=expert_enabled,
|
||||||
web_research_enabled=config.get("web_research_enabled", False),
|
web_research_enabled=config.get("web_research_enabled", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create agent
|
|
||||||
agent = create_agent(model, tools, checkpointer=memory)
|
agent = create_agent(model, tools, checkpointer=memory)
|
||||||
|
|
||||||
# Format prompt sections
|
|
||||||
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
|
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
|
||||||
human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
|
human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
|
||||||
web_research_section = (
|
web_research_section = (
|
||||||
|
|
@ -541,7 +520,6 @@ def run_planning_agent(
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build prompt
|
|
||||||
planning_prompt = PLANNING_PROMPT.format(
|
planning_prompt = PLANNING_PROMPT.format(
|
||||||
expert_section=expert_section,
|
expert_section=expert_section,
|
||||||
human_section=human_section,
|
human_section=human_section,
|
||||||
|
|
@ -557,8 +535,9 @@ def run_planning_agent(
|
||||||
else " Only request implementation if the user explicitly asked for changes to be made.",
|
else " Only request implementation if the user explicitly asked for changes to be made.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set up configuration
|
config = _global_memory.get("config", {}) if not config else config
|
||||||
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
|
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||||
|
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
|
||||||
if config:
|
if config:
|
||||||
run_config.update(config)
|
run_config.update(config)
|
||||||
|
|
||||||
|
|
@ -614,24 +593,19 @@ def run_task_implementation_agent(
|
||||||
logger.debug("Task details: base_task=%s, current_task=%s", base_task, task)
|
logger.debug("Task details: base_task=%s, current_task=%s", base_task, task)
|
||||||
logger.debug("Related files: %s", related_files)
|
logger.debug("Related files: %s", related_files)
|
||||||
|
|
||||||
# Initialize memory if not provided
|
|
||||||
if memory is None:
|
if memory is None:
|
||||||
memory = MemorySaver()
|
memory = MemorySaver()
|
||||||
|
|
||||||
# Set up thread ID
|
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
thread_id = str(uuid.uuid4())
|
thread_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# Configure tools
|
|
||||||
tools = get_implementation_tools(
|
tools = get_implementation_tools(
|
||||||
expert_enabled=expert_enabled,
|
expert_enabled=expert_enabled,
|
||||||
web_research_enabled=config.get("web_research_enabled", False),
|
web_research_enabled=config.get("web_research_enabled", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create agent
|
|
||||||
agent = create_agent(model, tools, checkpointer=memory)
|
agent = create_agent(model, tools, checkpointer=memory)
|
||||||
|
|
||||||
# Build prompt
|
|
||||||
prompt = IMPLEMENTATION_PROMPT.format(
|
prompt = IMPLEMENTATION_PROMPT.format(
|
||||||
base_task=base_task,
|
base_task=base_task,
|
||||||
task=task,
|
task=task,
|
||||||
|
|
@ -651,8 +625,9 @@ def run_task_implementation_agent(
|
||||||
else "",
|
else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set up configuration
|
config = _global_memory.get("config", {}) if not config else config
|
||||||
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
|
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||||
|
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
|
||||||
if config:
|
if config:
|
||||||
run_config.update(config)
|
run_config.update(config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1 +1,3 @@
|
||||||
"""Configuration utilities."""
|
"""Configuration utilities."""
|
||||||
|
|
||||||
|
DEFAULT_RECURSION_LIMIT = 100
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,123 @@
|
||||||
|
"""Unit tests for __main__.py argument parsing."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from ra_aid.__main__ import parse_arguments
|
||||||
|
from ra_aid.tools.memory import _global_memory
|
||||||
|
from ra_aid.config import DEFAULT_RECURSION_LIMIT
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dependencies(monkeypatch):
|
||||||
|
"""Mock all dependencies needed for main()."""
|
||||||
|
monkeypatch.setattr('ra_aid.__main__.check_dependencies', lambda: None)
|
||||||
|
|
||||||
|
monkeypatch.setattr('ra_aid.__main__.validate_environment',
|
||||||
|
lambda args: (True, [], True, []))
|
||||||
|
|
||||||
|
def mock_config_update(*args, **kwargs):
|
||||||
|
config = _global_memory.get("config", {})
|
||||||
|
if kwargs.get("temperature"):
|
||||||
|
config["temperature"] = kwargs["temperature"]
|
||||||
|
_global_memory["config"] = config
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr('ra_aid.__main__.initialize_llm',
|
||||||
|
mock_config_update)
|
||||||
|
|
||||||
|
monkeypatch.setattr('ra_aid.__main__.run_research_agent',
|
||||||
|
lambda *args, **kwargs: None)
|
||||||
|
|
||||||
|
def test_recursion_limit_in_global_config(mock_dependencies):
|
||||||
|
"""Test that recursion limit is correctly set in global config."""
|
||||||
|
from ra_aid.__main__ import main
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
_global_memory.clear()
|
||||||
|
|
||||||
|
with patch.object(sys, 'argv', ['ra-aid', '-m', 'test message']):
|
||||||
|
main()
|
||||||
|
assert _global_memory["config"]["recursion_limit"] == DEFAULT_RECURSION_LIMIT
|
||||||
|
|
||||||
|
_global_memory.clear()
|
||||||
|
|
||||||
|
with patch.object(sys, 'argv', ['ra-aid', '-m', 'test message', '--recursion-limit', '50']):
|
||||||
|
main()
|
||||||
|
assert _global_memory["config"]["recursion_limit"] == 50
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_recursion_limit():
|
||||||
|
"""Test that negative recursion limit raises error."""
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
parse_arguments(["-m", "test message", "--recursion-limit", "-1"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_recursion_limit():
|
||||||
|
"""Test that zero recursion limit raises error."""
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
parse_arguments(["-m", "test message", "--recursion-limit", "0"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_settings(mock_dependencies):
|
||||||
|
"""Test that various settings are correctly applied in global config."""
|
||||||
|
from ra_aid.__main__ import main
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
_global_memory.clear()
|
||||||
|
|
||||||
|
with patch.object(sys, 'argv', [
|
||||||
|
'ra-aid', '-m', 'test message',
|
||||||
|
'--cowboy-mode',
|
||||||
|
'--research-only',
|
||||||
|
'--provider', 'anthropic',
|
||||||
|
'--model', 'claude-3-5-sonnet-20241022',
|
||||||
|
'--expert-provider', 'openai',
|
||||||
|
'--expert-model', 'gpt-4',
|
||||||
|
'--temperature', '0.7',
|
||||||
|
'--disable-limit-tokens'
|
||||||
|
]):
|
||||||
|
main()
|
||||||
|
config = _global_memory["config"]
|
||||||
|
assert config["cowboy_mode"] is True
|
||||||
|
assert config["research_only"] is True
|
||||||
|
assert config["provider"] == "anthropic"
|
||||||
|
assert config["model"] == "claude-3-5-sonnet-20241022"
|
||||||
|
assert config["expert_provider"] == "openai"
|
||||||
|
assert config["expert_model"] == "gpt-4"
|
||||||
|
assert config["limit_tokens"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_temperature_validation(mock_dependencies):
|
||||||
|
"""Test that temperature argument is correctly passed to initialize_llm."""
|
||||||
|
from ra_aid.__main__ import main, initialize_llm
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
_global_memory.clear()
|
||||||
|
|
||||||
|
with patch('ra_aid.__main__.initialize_llm') as mock_init_llm:
|
||||||
|
with patch.object(sys, 'argv', ['ra-aid', '-m', 'test', '--temperature', '0.7']):
|
||||||
|
main()
|
||||||
|
mock_init_llm.assert_called_once()
|
||||||
|
assert mock_init_llm.call_args.kwargs['temperature'] == 0.7
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
with patch.object(sys, 'argv', ['ra-aid', '-m', 'test', '--temperature', '2.1']):
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_message():
|
||||||
|
"""Test that missing message argument raises error."""
|
||||||
|
# Test chat mode which doesn't require message
|
||||||
|
args = parse_arguments(["--chat"])
|
||||||
|
assert args.chat is True
|
||||||
|
assert args.message is None
|
||||||
|
|
||||||
|
# Test non-chat mode requires message
|
||||||
|
args = parse_arguments(["--provider", "openai"])
|
||||||
|
assert args.message is None
|
||||||
|
|
||||||
|
# Verify message is captured when provided
|
||||||
|
args = parse_arguments(["-m", "test"])
|
||||||
|
assert args.message == "test"
|
||||||
Loading…
Reference in New Issue