config repository

This commit is contained in:
AI Christianson 2025-03-04 21:01:08 -05:00
parent 3e68dd3fa6
commit 5bd8c76a22
19 changed files with 695 additions and 290 deletions

View File

@ -58,6 +58,10 @@ from ra_aid.database.repositories.related_files_repository import (
from ra_aid.database.repositories.work_log_repository import (
WorkLogRepositoryManager
)
from ra_aid.database.repositories.config_repository import (
ConfigRepositoryManager,
get_config_repository
)
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.console.output import cpm
@ -77,7 +81,7 @@ from ra_aid.prompts.chat_prompts import CHAT_PROMPT
from ra_aid.prompts.web_research_prompts import WEB_RESEARCH_PROMPT_SECTION_CHAT
from ra_aid.tool_configs import get_chat_tools, set_modification_tools
from ra_aid.tools.human import ask_human
from ra_aid.tools.memory import _global_memory, get_memory_value
from ra_aid.tools.memory import get_memory_value
logger = get_logger(__name__)
@ -338,7 +342,7 @@ implementation_memory = MemorySaver()
def is_informational_query() -> bool:
"""Determine if the current query is informational based on config settings."""
return _global_memory.get("config", {}).get("research_only", False)
return get_config_repository().get("research_only", False)
def is_stage_requested(stage: str) -> bool:
@ -404,13 +408,17 @@ def main():
except Exception as e:
logger.error(f"Database migration error: {str(e)}")
# Initialize empty config dictionary to be populated later
config = {}
# Initialize repositories with database connection
with KeyFactRepositoryManager(db) as key_fact_repo, \
KeySnippetRepositoryManager(db) as key_snippet_repo, \
HumanInputRepositoryManager(db) as human_input_repo, \
ResearchNoteRepositoryManager(db) as research_note_repo, \
RelatedFilesRepositoryManager() as related_files_repo, \
WorkLogRepositoryManager() as work_log_repo:
WorkLogRepositoryManager() as work_log_repo, \
ConfigRepositoryManager(config) as config_repo:
# This initializes all repositories and makes them available via their respective get methods
logger.debug("Initialized KeyFactRepository")
logger.debug("Initialized KeySnippetRepository")
@ -418,6 +426,7 @@ def main():
logger.debug("Initialized ResearchNoteRepository")
logger.debug("Initialized RelatedFilesRepository")
logger.debug("Initialized WorkLogRepository")
logger.debug("Initialized ConfigRepository")
# Check dependencies before proceeding
check_dependencies()
@ -520,13 +529,13 @@ def main():
"limit_tokens": args.disable_limit_tokens,
}
# Store config in global memory
_global_memory["config"] = config
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
_global_memory["config"]["temperature"] = args.temperature
# Store config in repository
config_repo.update(config)
config_repo.set("provider", args.provider)
config_repo.set("model", args.model)
config_repo.set("expert_provider", args.expert_provider)
config_repo.set("expert_model", args.expert_model)
config_repo.set("temperature", args.temperature)
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)
@ -594,33 +603,27 @@ def main():
"test_cmd_timeout": args.test_cmd_timeout,
}
# Store config in global memory for access by is_informational_query
_global_memory["config"] = config
# Store config in repository
config_repo.update(config)
# Store base provider/model configuration
_global_memory["config"]["provider"] = args.provider
_global_memory["config"]["model"] = args.model
config_repo.set("provider", args.provider)
config_repo.set("model", args.model)
# Store expert provider/model (no fallback)
_global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory["config"]["expert_model"] = args.expert_model
config_repo.set("expert_provider", args.expert_provider)
config_repo.set("expert_model", args.expert_model)
# Store planner config with fallback to base values
_global_memory["config"]["planner_provider"] = (
args.planner_provider or args.provider
)
_global_memory["config"]["planner_model"] = args.planner_model or args.model
config_repo.set("planner_provider", args.planner_provider or args.provider)
config_repo.set("planner_model", args.planner_model or args.model)
# Store research config with fallback to base values
_global_memory["config"]["research_provider"] = (
args.research_provider or args.provider
)
_global_memory["config"]["research_model"] = (
args.research_model or args.model
)
config_repo.set("research_provider", args.research_provider or args.provider)
config_repo.set("research_model", args.research_model or args.model)
# Store temperature in global config
_global_memory["config"]["temperature"] = args.temperature
# Store temperature in config
config_repo.set("temperature", args.temperature)
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)

View File

@ -94,11 +94,11 @@ from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
from ra_aid.tools.memory import (
_global_memory,
get_memory_value,
get_related_files,
log_work_event,
)
from ra_aid.database.repositories.config_repository import get_config_repository
console = Console()
@ -302,7 +302,7 @@ def create_agent(
config['limit_tokens'] = False.
"""
try:
config = _global_memory.get("config", {})
config = get_config_repository().get_all()
max_input_tokens = (
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
)
@ -319,7 +319,7 @@ def create_agent(
except Exception as e:
# Default to REACT agent if provider/model detection fails
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
config = _global_memory.get("config", {})
config = get_config_repository().get_all()
max_input_tokens = get_model_token_limit(config, agent_type)
agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
return create_react_agent(model, tools, **agent_kwargs)
@ -443,7 +443,7 @@ def run_research_agent(
new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "",
)
config = _global_memory.get("config", {}) if not config else config
config = get_config_repository().get_all() if not config else config
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
run_config = {
"configurable": {"thread_id": thread_id},
@ -575,7 +575,7 @@ def run_web_research_agent(
related_files=related_files,
)
config = _global_memory.get("config", {}) if not config else config
config = get_config_repository().get_all() if not config else config
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
run_config = {
@ -709,7 +709,7 @@ def run_planning_agent(
),
)
config = _global_memory.get("config", {}) if not config else config
config = get_config_repository().get_all() if not config else config
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
run_config = {
"configurable": {"thread_id": thread_id},
@ -824,7 +824,7 @@ def run_task_implementation_agent(
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
human_section=(
HUMAN_PROMPT_SECTION_IMPLEMENTATION
if _global_memory.get("config", {}).get("hil", False)
if get_config_repository().get("hil", False)
else ""
),
web_research_section=(
@ -834,7 +834,7 @@ def run_task_implementation_agent(
),
)
config = _global_memory.get("config", {}) if not config else config
config = get_config_repository().get_all() if not config else config
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
run_config = {
"configurable": {"thread_id": thread_id},

View File

@ -19,9 +19,10 @@ logger = logging.getLogger(__name__)
from ra_aid.agent_utils import create_agent, run_agent_with_retry
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.llm import initialize_llm
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
from ra_aid.tools.memory import log_work_event, _global_memory
from ra_aid.tools.memory import log_work_event
console = Console()
@ -149,7 +150,7 @@ def run_key_facts_gc_agent() -> None:
formatted_facts = "\n".join([f"Fact #{k}: {v}" for k, v in facts_dict.items()])
# Retrieve configuration
llm_config = _global_memory.get("config", {})
llm_config = get_config_repository().get_all()
# Initialize the LLM model
model = initialize_llm(

View File

@ -16,9 +16,10 @@ from rich.panel import Panel
from ra_aid.agent_utils import create_agent, run_agent_with_retry
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.llm import initialize_llm
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
from ra_aid.tools.memory import log_work_event, _global_memory
from ra_aid.tools.memory import log_work_event
console = Console()
@ -153,7 +154,7 @@ def run_key_snippets_gc_agent() -> None:
])
# Retrieve configuration
llm_config = _global_memory.get("config", {})
llm_config = get_config_repository().get_all()
# Initialize the LLM model
model = initialize_llm(

View File

@ -19,9 +19,10 @@ logger = logging.getLogger(__name__)
from ra_aid.agent_utils import create_agent, run_agent_with_retry
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.llm import initialize_llm
from ra_aid.model_formatters.research_notes_formatter import format_research_note
from ra_aid.tools.memory import log_work_event, _global_memory
from ra_aid.tools.memory import log_work_event
console = Console()
@ -154,7 +155,7 @@ def run_research_notes_gc_agent(threshold: int = 30) -> None:
formatted_notes = "\n".join([f"Note #{k}: {v}" for k, v in notes_dict.items()])
# Retrieve configuration
llm_config = _global_memory.get("config", {})
llm_config = get_config_repository().get_all()
# Initialize the LLM model
model = initialize_llm(

View File

@ -0,0 +1,165 @@
"""Repository for managing configuration values."""
import contextvars
from typing import Any, Dict, Optional
# Create contextvar to hold the ConfigRepository instance
config_repo_var = contextvars.ContextVar("config_repo", default=None)
class ConfigRepository:
"""
Repository for managing configuration values in memory.
This class provides methods to get, set, update, and retrieve all configuration values.
It does not require database models and operates entirely in memory.
"""
def __init__(self, initial_config: Optional[Dict[str, Any]] = None):
"""
Initialize the ConfigRepository.
Args:
initial_config: Optional dictionary of initial configuration values
"""
self._config: Dict[str, Any] = {}
# Initialize with default values from config.py
from ra_aid.config import (
DEFAULT_RECURSION_LIMIT,
DEFAULT_MAX_TEST_CMD_RETRIES,
DEFAULT_MAX_TOOL_FAILURES,
FALLBACK_TOOL_MODEL_LIMIT,
RETRY_FALLBACK_COUNT,
DEFAULT_TEST_CMD_TIMEOUT,
VALID_PROVIDERS,
)
self._config = {
"recursion_limit": DEFAULT_RECURSION_LIMIT,
"max_test_cmd_retries": DEFAULT_MAX_TEST_CMD_RETRIES,
"max_tool_failures": DEFAULT_MAX_TOOL_FAILURES,
"fallback_tool_model_limit": FALLBACK_TOOL_MODEL_LIMIT,
"retry_fallback_count": RETRY_FALLBACK_COUNT,
"test_cmd_timeout": DEFAULT_TEST_CMD_TIMEOUT,
"valid_providers": VALID_PROVIDERS,
}
# Update with any provided initial configuration
if initial_config:
self._config.update(initial_config)
def get(self, key: str, default: Any = None) -> Any:
"""
Get a configuration value by key.
Args:
key: Configuration key to retrieve
default: Default value to return if key is not found
Returns:
The configuration value or default if not found
"""
return self._config.get(key, default)
def set(self, key: str, value: Any) -> None:
"""
Set a configuration value by key.
Args:
key: Configuration key to set
value: Value to set for the key
"""
self._config[key] = value
def update(self, config_dict: Dict[str, Any]) -> None:
"""
Update multiple configuration values at once.
Args:
config_dict: Dictionary of configuration key-value pairs to update
"""
self._config.update(config_dict)
def get_all(self) -> Dict[str, Any]:
"""
Get all configuration values.
Returns:
Dictionary containing all configuration values
"""
return self._config.copy()
class ConfigRepositoryManager:
"""
Context manager for ConfigRepository.
This class provides a context manager interface for ConfigRepository,
using the contextvars approach for thread safety.
Example:
with ConfigRepositoryManager() as repo:
# Use the repository
value = repo.get("key")
repo.set("key", new_value)
"""
def __init__(self, initial_config: Optional[Dict[str, Any]] = None):
"""
Initialize the ConfigRepositoryManager.
Args:
initial_config: Optional dictionary of initial configuration values
"""
self.initial_config = initial_config
def __enter__(self) -> 'ConfigRepository':
"""
Initialize the ConfigRepository and return it.
Returns:
ConfigRepository: The initialized repository
"""
repo = ConfigRepository(self.initial_config)
config_repo_var.set(repo)
return repo
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[object],
) -> None:
"""
Reset the repository when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
# Reset the contextvar to None
config_repo_var.set(None)
# Don't suppress exceptions
return False
def get_config_repository() -> ConfigRepository:
"""
Get the current ConfigRepository instance.
Returns:
ConfigRepository: The current repository instance
Raises:
RuntimeError: If no repository is set in the current context
"""
repo = config_repo_var.get()
if repo is None:
raise RuntimeError(
"ConfigRepository not initialized in current context. "
"Make sure to use ConfigRepositoryManager."
)
return repo

View File

@ -27,6 +27,7 @@ from ra_aid.tools.agent import (
request_web_research,
)
from ra_aid.tools.memory import plan_implementation_completed
from ra_aid.database.repositories.config_repository import get_config_repository
def set_modification_tools(use_aider=False):
@ -98,13 +99,11 @@ def get_all_tools() -> list[BaseTool]:
# Define constant tool groups
# Get config from global memory for use_aider value
# Get config from repository for use_aider value
_config = {}
try:
from ra_aid.tools.memory import _global_memory
_config = _global_memory.get("config", {})
except ImportError:
_config = get_config_repository().get_all()
except (ImportError, RuntimeError):
pass
READ_ONLY_TOOLS = get_read_only_tools(use_aider=_config.get("use_aider", False))
@ -139,10 +138,8 @@ def get_research_tools(
# Get config for use_aider value
use_aider = False
try:
from ra_aid.tools.memory import _global_memory
use_aider = _global_memory.get("config", {}).get("use_aider", False)
except ImportError:
use_aider = get_config_repository().get("use_aider", False)
except (ImportError, RuntimeError):
pass
# Start with read-only tools
@ -180,10 +177,8 @@ def get_planning_tools(
# Get config for use_aider value
use_aider = False
try:
from ra_aid.tools.memory import _global_memory
use_aider = _global_memory.get("config", {}).get("use_aider", False)
except ImportError:
use_aider = get_config_repository().get("use_aider", False)
except (ImportError, RuntimeError):
pass
# Start with read-only tools
@ -219,10 +214,8 @@ def get_implementation_tools(
# Get config for use_aider value
use_aider = False
try:
from ra_aid.tools.memory import _global_memory
use_aider = _global_memory.get("config", {}).get("use_aider", False)
except ImportError:
use_aider = get_config_repository().get("use_aider", False)
except (ImportError, RuntimeError):
pass
# Start with read-only tools
@ -285,4 +278,4 @@ def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = Fal
if web_research_enabled:
tools.append(request_web_research)
return tools
return tools

View File

@ -18,13 +18,13 @@ from ra_aid.console.formatting import print_error
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.exceptions import AgentInterrupt
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
from ra_aid.tools.memory import _global_memory
from ..console import print_task_header
from ..llm import initialize_llm
@ -52,7 +52,7 @@ def request_research(query: str) -> ResearchResult:
query: The research question or project description
"""
# Initialize model from config
config = _global_memory.get("config", {})
config = get_config_repository().get_all()
model = initialize_llm(
config.get("provider", "anthropic"),
config.get("model", "claude-3-7-sonnet-20250219"),
@ -165,7 +165,7 @@ def request_web_research(query: str) -> ResearchResult:
query: The research question or project description
"""
# Initialize model from config
config = _global_memory.get("config", {})
config = get_config_repository().get_all()
model = initialize_llm(
config.get("provider", "anthropic"),
config.get("model", "claude-3-7-sonnet-20250219"),
@ -246,7 +246,7 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
query: The research question or project description
"""
# Initialize model from config
config = _global_memory.get("config", {})
config = get_config_repository().get_all()
model = initialize_llm(
config.get("provider", "anthropic"),
config.get("model", "claude-3-7-sonnet-20250219"),
@ -335,7 +335,7 @@ def request_task_implementation(task_spec: str) -> str:
task_spec: REQUIRED The full task specification (markdown format, typically one part of the overall plan)
"""
# Initialize model from config
config = _global_memory.get("config", {})
config = get_config_repository().get_all()
model = initialize_llm(
config.get("provider", "anthropic"),
config.get("model", "claude-3-5-sonnet-20241022"),
@ -474,7 +474,7 @@ def request_implementation(task_spec: str) -> str:
task_spec: The task specification to plan implementation for
"""
# Initialize model from config
config = _global_memory.get("config", {})
config = get_config_repository().get_all()
model = initialize_llm(
config.get("provider", "anthropic"),
config.get("model", "claude-3-5-sonnet-20241022"),

View File

@ -13,11 +13,12 @@ from ..database.repositories.key_fact_repository import get_key_fact_repository
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
from ..database.repositories.related_files_repository import get_related_files_repository
from ..database.repositories.research_note_repository import get_research_note_repository
from ..database.repositories.config_repository import get_config_repository
from ..llm import initialize_expert_llm
from ..model_formatters import format_key_facts_dict
from ..model_formatters.key_snippets_formatter import format_key_snippets_dict
from ..model_formatters.research_notes_formatter import format_research_notes_dict
from .memory import _global_memory, get_memory_value
from .memory import get_memory_value
console = Console()
_model = None
@ -27,9 +28,9 @@ def get_model():
global _model
try:
if _model is None:
config = _global_memory["config"]
provider = config.get("expert_provider") or config.get("provider")
model = config.get("expert_model") or config.get("model")
config_repo = get_config_repository()
provider = config_repo.get("expert_provider") or config_repo.get("provider")
model = config_repo.get("expert_model") or config_repo.get("model")
_model = initialize_expert_llm(provider, model)
except Exception as e:
_model = None

View File

@ -58,10 +58,10 @@ def ask_human(question: str) -> str:
# Record human response in database
try:
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.tools.memory import _global_memory
from ra_aid.database.repositories.config_repository import get_config_repository
# Determine the source based on context
config = _global_memory.get("config", {})
config = get_config_repository().get_all()
# If chat_mode is enabled, use 'chat', otherwise determine if hil mode is active
if config.get("chat_mode", False):
source = "chat"

View File

@ -13,7 +13,8 @@ from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_BASE_LATENCY, models_params
from ra_aid.proc.interactive import run_interactive_command
from ra_aid.text.processing import truncate_output
from ra_aid.tools.memory import _global_memory, log_work_event
from ra_aid.tools.memory import log_work_event
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
console = Console()
@ -107,8 +108,9 @@ def run_programming_task(
)
# Add config file if specified
if "config" in _global_memory and _global_memory["config"].get("aider_config"):
command.extend(["--config", _global_memory["config"]["aider_config"]])
config = get_config_repository().get_all()
if config.get("aider_config"):
command.extend(["--config", config["aider_config"]])
# if environment variable AIDER_FLAGS exists then parse
if "AIDER_FLAGS" in os.environ:
@ -147,8 +149,9 @@ def run_programming_task(
# Run the command interactively
print()
# Get provider/model specific latency coefficient
provider = _global_memory.get("config", {}).get("provider", "")
model = _global_memory.get("config", {}).get("model", "")
config = get_config_repository().get_all()
provider = config.get("provider", "")
model = config.get("model", "")
latency = (
models_params.get(provider, {})
.get(model, {})

View File

@ -9,6 +9,7 @@ from ra_aid.console.cowboy_messages import get_cowboy_message
from ra_aid.proc.interactive import run_interactive_command
from ra_aid.text.processing import truncate_output
from ra_aid.tools.memory import _global_memory, log_work_event
from ra_aid.database.repositories.config_repository import get_config_repository
console = Console()
@ -46,7 +47,7 @@ def run_shell_command(
4. Add flags e.g. git --no-pager in order to reduce interaction required by the human.
"""
# Check if we need approval
cowboy_mode = _global_memory.get("config", {}).get("cowboy_mode", False)
cowboy_mode = get_config_repository().get("cowboy_mode", False)
if cowboy_mode:
console.print("")
@ -74,7 +75,7 @@ def run_shell_command(
"success": False,
}
elif response == "c":
_global_memory["config"]["cowboy_mode"] = True
get_config_repository().set("cowboy_mode", True)
console.print("")
console.print(" " + get_cowboy_message())
console.print("")
@ -96,4 +97,4 @@ def run_shell_command(
except Exception as e:
print()
console.print(Panel(str(e), title="❌ Error", border_style="red"))
return {"output": str(e), "return_code": 1, "success": False}
return {"output": str(e), "return_code": 1, "success": False}

View File

@ -7,10 +7,25 @@ ensuring consistent test environments and proper isolation.
import os
from pathlib import Path
from unittest.mock import MagicMock
import pytest
@pytest.fixture()
def mock_config_repository():
"""Mock the config repository."""
from ra_aid.database.repositories.config_repository import get_config_repository
repo = MagicMock()
# Default config values
config_values = {"recursion_limit": 2}
repo.get_all.return_value = config_values
repo.get.side_effect = lambda key, default=None: config_values.get(key, default)
get_config_repository.return_value = repo
yield repo
@pytest.fixture(autouse=True)
def isolated_db_environment(tmp_path, monkeypatch):
"""

View File

@ -1,7 +1,7 @@
"""Unit tests for agent_utils.py."""
from typing import Any, Dict, Literal
from unittest.mock import Mock, patch
from unittest.mock import Mock, patch, MagicMock
import litellm
import pytest
@ -19,6 +19,7 @@ from ra_aid.agent_utils import (
state_modifier,
)
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var
@pytest.fixture
@ -29,40 +30,70 @@ def mock_model():
@pytest.fixture
def mock_memory():
"""Fixture providing a mock global memory store."""
with patch("ra_aid.agent_utils._global_memory") as mock_mem:
mock_mem.get.return_value = {}
yield mock_mem
def mock_config_repository():
"""Mock the ConfigRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Create a dictionary to simulate config
config = {}
# Setup get method to return config values
def get_config(key, default=None):
return config.get(key, default)
mock_repo.get.side_effect = get_config
# Setup get_all method to return all config values
mock_repo.get_all.return_value = config
# Setup set method to update config values
def set_config(key, value):
config[key] = value
mock_repo.set.side_effect = set_config
# Setup update method to update multiple config values
def update_config(update_dict):
config.update(update_dict)
mock_repo.update.side_effect = update_config
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
def test_get_model_token_limit_anthropic(mock_memory):
def test_get_model_token_limit_anthropic(mock_config_repository):
"""Test get_model_token_limit with Anthropic model."""
config = {"provider": "anthropic", "model": "claude2"}
mock_config_repository.update(config)
token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
def test_get_model_token_limit_openai(mock_memory):
def test_get_model_token_limit_openai(mock_config_repository):
"""Test get_model_token_limit with OpenAI model."""
config = {"provider": "openai", "model": "gpt-4"}
mock_config_repository.update(config)
token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
def test_get_model_token_limit_unknown(mock_memory):
def test_get_model_token_limit_unknown(mock_config_repository):
"""Test get_model_token_limit with unknown provider/model."""
config = {"provider": "unknown", "model": "unknown-model"}
mock_config_repository.update(config)
token_limit = get_model_token_limit(config, "default")
assert token_limit is None
def test_get_model_token_limit_missing_config(mock_memory):
def test_get_model_token_limit_missing_config(mock_config_repository):
"""Test get_model_token_limit with missing configuration."""
config = {}
mock_config_repository.update(config)
token_limit = get_model_token_limit(config, "default")
assert token_limit is None
@ -108,9 +139,9 @@ def test_get_model_token_limit_unexpected_error():
assert token_limit is None
def test_create_agent_anthropic(mock_model, mock_memory):
def test_create_agent_anthropic(mock_model, mock_config_repository):
"""Test create_agent with Anthropic Claude model."""
mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"}
mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
mock_react.return_value = "react_agent"
@ -125,9 +156,9 @@ def test_create_agent_anthropic(mock_model, mock_memory):
)
def test_create_agent_openai(mock_model, mock_memory):
def test_create_agent_openai(mock_model, mock_config_repository):
"""Test create_agent with OpenAI model."""
mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
mock_config_repository.update({"provider": "openai", "model": "gpt-4"})
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
mock_ciayn.return_value = "ciayn_agent"
@ -142,9 +173,9 @@ def test_create_agent_openai(mock_model, mock_memory):
)
def test_create_agent_no_token_limit(mock_model, mock_memory):
def test_create_agent_no_token_limit(mock_model, mock_config_repository):
"""Test create_agent when no token limit is found."""
mock_memory.get.return_value = {"provider": "unknown", "model": "unknown-model"}
mock_config_repository.update({"provider": "unknown", "model": "unknown-model"})
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
mock_ciayn.return_value = "ciayn_agent"
@ -159,9 +190,9 @@ def test_create_agent_no_token_limit(mock_model, mock_memory):
)
def test_create_agent_missing_config(mock_model, mock_memory):
def test_create_agent_missing_config(mock_model, mock_config_repository):
"""Test create_agent with missing configuration."""
mock_memory.get.return_value = {"provider": "openai"}
mock_config_repository.update({"provider": "openai"})
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
mock_ciayn.return_value = "ciayn_agent"
@ -205,9 +236,9 @@ def test_state_modifier(mock_messages):
assert result[-1] == mock_messages[-1]
def test_create_agent_with_checkpointer(mock_model, mock_memory):
def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
"""Test create_agent with checkpointer argument."""
mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
mock_config_repository.update({"provider": "openai", "model": "gpt-4"})
mock_checkpointer = Mock()
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
@ -223,13 +254,13 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory):
)
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
"""Test create_agent sets up token limiting for Claude models when enabled."""
mock_memory.get.return_value = {
mock_config_repository.update({
"provider": "anthropic",
"model": "claude-2",
"limit_tokens": True,
}
})
with (
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
@ -246,13 +277,13 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
assert callable(args[1]["state_modifier"])
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory):
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
mock_memory.get.return_value = {
mock_config_repository.update({
"provider": "anthropic",
"model": "claude-2",
"limit_tokens": False,
}
})
with (
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
@ -267,7 +298,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
mock_react.assert_called_once_with(mock_model, [], version="v2")
def test_get_model_token_limit_research(mock_memory):
def test_get_model_token_limit_research(mock_config_repository):
"""Test get_model_token_limit with research provider and model."""
config = {
"provider": "openai",
@ -275,13 +306,15 @@ def test_get_model_token_limit_research(mock_memory):
"research_provider": "anthropic",
"research_model": "claude-2",
}
mock_config_repository.update(config)
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 150000}
token_limit = get_model_token_limit(config, "research")
assert token_limit == 150000
def test_get_model_token_limit_planner(mock_memory):
def test_get_model_token_limit_planner(mock_config_repository):
"""Test get_model_token_limit with planner provider and model."""
config = {
"provider": "openai",
@ -289,6 +322,8 @@ def test_get_model_token_limit_planner(mock_memory):
"planner_provider": "deepseek",
"planner_model": "dsm-1",
}
mock_config_repository.update(config)
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 120000}
token_limit = get_model_token_limit(config, "planner")

View File

@ -1,41 +1,35 @@
"""Tests for the is_informational_query and is_stage_requested functions."""
import pytest
from ra_aid.__main__ import is_informational_query, is_stage_requested
from ra_aid.tools.memory import _global_memory
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
def test_is_informational_query():
@pytest.fixture
def config_repo():
"""Fixture for config repository."""
with ConfigRepositoryManager() as repo:
yield repo
def test_is_informational_query(config_repo):
"""Test that is_informational_query only depends on research_only config setting."""
# Clear global memory to ensure clean state
_global_memory.clear()
# When research_only is True, should return True
_global_memory["config"] = {"research_only": True}
config_repo.set("research_only", True)
assert is_informational_query() is True
# When research_only is False, should return False
_global_memory["config"] = {"research_only": False}
config_repo.set("research_only", False)
assert is_informational_query() is False
# When config is empty, should return False (default)
_global_memory.clear()
_global_memory["config"] = {}
assert is_informational_query() is False
# When global memory is empty, should return False (default)
_global_memory.clear()
config_repo.update({})
assert is_informational_query() is False
def test_is_stage_requested():
"""Test that is_stage_requested always returns False now."""
# Clear global memory to ensure clean state
_global_memory.clear()
# Should always return False regardless of input
assert is_stage_requested("implementation") is False
assert is_stage_requested("anything_else") is False
# Even if we set implementation_requested in global memory
_global_memory["implementation_requested"] = True
assert is_stage_requested("implementation") is False
assert is_stage_requested("anything_else") is False

View File

@ -7,14 +7,50 @@ from ra_aid.__main__ import parse_arguments
from ra_aid.config import DEFAULT_RECURSION_LIMIT
from ra_aid.tools.memory import _global_memory
from ra_aid.database.repositories.work_log_repository import WorkLogEntry
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository
@pytest.fixture(autouse=True)
def mock_config_repository():
"""Mock the ConfigRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Create a dictionary to simulate config
config = {}
# Setup get method to return config values
def get_config(key, default=None):
return config.get(key, default)
mock_repo.get.side_effect = get_config
# Setup set method to update config values
def set_config(key, value):
config[key] = value
mock_repo.set.side_effect = set_config
# Setup update method to update multiple config values
def update_config(config_dict):
config.update(config_dict)
mock_repo.update.side_effect = update_config
# Setup get_all method to return the config dict
def get_all_config():
return config.copy()
mock_repo.get_all.side_effect = get_all_config
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
@pytest.fixture
def mock_dependencies(monkeypatch):
"""Mock all dependencies needed for main()."""
# Initialize global memory with necessary keys to prevent KeyError
# Initialize global memory
_global_memory.clear()
_global_memory["config"] = {}
# Mock dependencies that interact with external systems
monkeypatch.setattr("ra_aid.__main__.check_dependencies", lambda: None)
@ -26,10 +62,9 @@ def mock_dependencies(monkeypatch):
# Mock LLM initialization
def mock_config_update(*args, **kwargs):
config = _global_memory.get("config", {})
config_repo = get_config_repository()
if kwargs.get("temperature"):
config["temperature"] = kwargs["temperature"]
_global_memory["config"] = config
config_repo.set("temperature", kwargs["temperature"])
return None
monkeypatch.setattr("ra_aid.__main__.initialize_llm", mock_config_update)
@ -107,26 +142,52 @@ def mock_work_log_repository():
yield mock_repo
def test_recursion_limit_in_global_config(mock_dependencies):
def test_recursion_limit_in_global_config(mock_dependencies, mock_config_repository):
"""Test that recursion limit is correctly set in global config."""
import sys
from unittest.mock import patch
from ra_aid.__main__ import main
_global_memory.clear()
# Clear the mock repository before each test
mock_config_repository.update.reset_mock()
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
# Test default recursion limit
with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]):
main()
# Check that the recursion_limit value was included in the update call
mock_config_repository.update.assert_called()
# Get the call arguments
call_args = mock_config_repository.update.call_args_list
# Find the call that includes recursion_limit
recursion_limit_found = False
for args, _ in call_args:
config_dict = args[0]
if "recursion_limit" in config_dict and config_dict["recursion_limit"] == DEFAULT_RECURSION_LIMIT:
recursion_limit_found = True
break
assert recursion_limit_found, f"recursion_limit not found in update calls: {call_args}"
with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]):
main()
assert _global_memory["config"]["recursion_limit"] == DEFAULT_RECURSION_LIMIT
_global_memory.clear()
with patch.object(
sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"]
):
main()
assert _global_memory["config"]["recursion_limit"] == 50
# Reset mock to clear call history
mock_config_repository.update.reset_mock()
# Test custom recursion limit
with patch.object(sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"]):
main()
# Check that the recursion_limit value was included in the update call
mock_config_repository.update.assert_called()
# Get the call arguments
call_args = mock_config_repository.update.call_args_list
# Find the call that includes recursion_limit with value 50
recursion_limit_found = False
for args, _ in call_args:
config_dict = args[0]
if "recursion_limit" in config_dict and config_dict["recursion_limit"] == 50:
recursion_limit_found = True
break
assert recursion_limit_found, f"recursion_limit=50 not found in update calls: {call_args}"
def test_negative_recursion_limit():
@ -141,70 +202,83 @@ def test_zero_recursion_limit():
parse_arguments(["-m", "test message", "--recursion-limit", "0"])
def test_config_settings(mock_dependencies):
def test_config_settings(mock_dependencies, mock_config_repository):
"""Test that various settings are correctly applied in global config."""
import sys
from unittest.mock import patch
from ra_aid.__main__ import main
_global_memory.clear()
with patch.object(
sys,
"argv",
[
"ra-aid",
"-m",
"test message",
"--cowboy-mode",
"--research-only",
"--provider",
"anthropic",
"--model",
"claude-3-7-sonnet-20250219",
"--expert-provider",
"openai",
"--expert-model",
"gpt-4",
"--temperature",
"0.7",
"--disable-limit-tokens",
],
):
main()
config = _global_memory["config"]
assert config["cowboy_mode"] is True
assert config["research_only"] is True
assert config["provider"] == "anthropic"
assert config["model"] == "claude-3-7-sonnet-20250219"
assert config["expert_provider"] == "openai"
assert config["expert_model"] == "gpt-4"
assert config["limit_tokens"] is False
# Clear the mock repository before each test
mock_config_repository.update.reset_mock()
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
with patch.object(
sys,
"argv",
[
"ra-aid",
"-m",
"test message",
"--cowboy-mode",
"--research-only",
"--provider",
"anthropic",
"--model",
"claude-3-7-sonnet-20250219",
"--expert-provider",
"openai",
"--expert-model",
"gpt-4",
"--temperature",
"0.7",
"--disable-limit-tokens",
],
):
main()
# Verify config values are set via the update method
mock_config_repository.update.assert_called()
# Get the call arguments
call_args = mock_config_repository.update.call_args_list
# Check for config values in the update calls
for args, _ in call_args:
config_dict = args[0]
if "cowboy_mode" in config_dict:
assert config_dict["cowboy_mode"] is True
if "research_only" in config_dict:
assert config_dict["research_only"] is True
if "limit_tokens" in config_dict:
assert config_dict["limit_tokens"] is False
# Check provider and model settings via set method
mock_config_repository.set.assert_any_call("provider", "anthropic")
mock_config_repository.set.assert_any_call("model", "claude-3-7-sonnet-20250219")
mock_config_repository.set.assert_any_call("expert_provider", "openai")
mock_config_repository.set.assert_any_call("expert_model", "gpt-4")
def test_temperature_validation(mock_dependencies):
def test_temperature_validation(mock_dependencies, mock_config_repository):
"""Test that temperature argument is correctly passed to initialize_llm."""
import sys
from unittest.mock import patch, ANY
from ra_aid.__main__ import main
# Reset global memory for clean test
_global_memory.clear()
_global_memory["config"] = {}
# Test valid temperature (0.7)
with patch("ra_aid.__main__.initialize_llm", return_value=None) as mock_init_llm:
# Also patch any calls that would actually use the mocked initialize_llm function
with patch("ra_aid.__main__.run_research_agent", return_value=None):
with patch("ra_aid.__main__.run_planning_agent", return_value=None):
with patch.object(
sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"]
):
main()
# Check if temperature was stored in config correctly
assert _global_memory["config"]["temperature"] == 0.7
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
# Test valid temperature (0.7)
with patch("ra_aid.__main__.initialize_llm", return_value=None) as mock_init_llm:
# Also patch any calls that would actually use the mocked initialize_llm function
with patch("ra_aid.__main__.run_research_agent", return_value=None):
with patch("ra_aid.__main__.run_planning_agent", return_value=None):
with patch.object(
sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"]
):
main()
# Verify that the temperature was set in the config repository
mock_config_repository.set.assert_any_call("temperature", 0.7)
# Test invalid temperature (2.1)
with pytest.raises(SystemExit):
@ -230,61 +304,67 @@ def test_missing_message():
assert args.message == "test"
def test_research_model_provider_args(mock_dependencies):
def test_research_model_provider_args(mock_dependencies, mock_config_repository):
"""Test that research-specific model/provider args are correctly stored in config."""
import sys
from unittest.mock import patch
from ra_aid.__main__ import main
_global_memory.clear()
with patch.object(
sys,
"argv",
[
"ra-aid",
"-m",
"test message",
"--research-provider",
"anthropic",
"--research-model",
"claude-3-haiku-20240307",
"--planner-provider",
"openai",
"--planner-model",
"gpt-4",
],
):
main()
config = _global_memory["config"]
assert config["research_provider"] == "anthropic"
assert config["research_model"] == "claude-3-haiku-20240307"
assert config["planner_provider"] == "openai"
assert config["planner_model"] == "gpt-4"
# Reset mocks
mock_config_repository.set.reset_mock()
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
with patch.object(
sys,
"argv",
[
"ra-aid",
"-m",
"test message",
"--research-provider",
"anthropic",
"--research-model",
"claude-3-haiku-20240307",
"--planner-provider",
"openai",
"--planner-model",
"gpt-4",
],
):
main()
# Verify the mock repo's set method was called with the expected values
mock_config_repository.set.assert_any_call("research_provider", "anthropic")
mock_config_repository.set.assert_any_call("research_model", "claude-3-haiku-20240307")
mock_config_repository.set.assert_any_call("planner_provider", "openai")
mock_config_repository.set.assert_any_call("planner_model", "gpt-4")
def test_planner_model_provider_args(mock_dependencies):
def test_planner_model_provider_args(mock_dependencies, mock_config_repository):
"""Test that planner provider/model args fall back to main config when not specified."""
import sys
from unittest.mock import patch
from ra_aid.__main__ import main
_global_memory.clear()
with patch.object(
sys,
"argv",
["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"],
):
main()
config = _global_memory["config"]
assert config["planner_provider"] == "openai"
assert config["planner_model"] == "gpt-4"
# Reset mocks
mock_config_repository.set.reset_mock()
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
with patch.object(
sys,
"argv",
["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"],
):
main()
# Verify the mock repo's set method was called with the expected values
mock_config_repository.set.assert_any_call("planner_provider", "openai")
mock_config_repository.set.assert_any_call("planner_model", "gpt-4")
def test_use_aider_flag(mock_dependencies):
def test_use_aider_flag(mock_dependencies, mock_config_repository):
"""Test that use-aider flag is correctly stored in config."""
import sys
from unittest.mock import patch
@ -292,44 +372,68 @@ def test_use_aider_flag(mock_dependencies):
from ra_aid.__main__ import main
from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools
_global_memory.clear()
# Reset mocks
mock_config_repository.update.reset_mock()
# Reset to default state
set_modification_tools(False)
# Check default behavior (use_aider=False)
with patch.object(
sys,
"argv",
["ra-aid", "-m", "test message"],
):
main()
config = _global_memory["config"]
assert config.get("use_aider") is False
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
# Check default behavior (use_aider=False)
with patch.object(
sys,
"argv",
["ra-aid", "-m", "test message"],
):
main()
# Verify use_aider is set to False in the update call
mock_config_repository.update.assert_called()
# Get the call arguments
call_args = mock_config_repository.update.call_args_list
# Find the call that includes use_aider
use_aider_found = False
for args, _ in call_args:
config_dict = args[0]
if "use_aider" in config_dict and config_dict["use_aider"] is False:
use_aider_found = True
break
assert use_aider_found, f"use_aider=False not found in update calls: {call_args}"
# Check that file tools are enabled by default
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
assert "file_str_replace" in tool_names
assert "put_complete_file_contents" in tool_names
assert "run_programming_task" not in tool_names
# Check that file tools are enabled by default
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
assert "file_str_replace" in tool_names
assert "put_complete_file_contents" in tool_names
assert "run_programming_task" not in tool_names
_global_memory.clear()
# Reset mocks
mock_config_repository.update.reset_mock()
# Check with --use-aider flag
with patch.object(
sys,
"argv",
["ra-aid", "-m", "test message", "--use-aider"],
):
main()
config = _global_memory["config"]
assert config.get("use_aider") is True
# Check with --use-aider flag
with patch.object(
sys,
"argv",
["ra-aid", "-m", "test message", "--use-aider"],
):
main()
# Verify use_aider is set to True in the update call
mock_config_repository.update.assert_called()
# Get the call arguments
call_args = mock_config_repository.update.call_args_list
# Find the call that includes use_aider
use_aider_found = False
for args, _ in call_args:
config_dict = args[0]
if "use_aider" in config_dict and config_dict["use_aider"] is True:
use_aider_found = True
break
assert use_aider_found, f"use_aider=True not found in update calls: {call_args}"
# Check that run_programming_task is enabled
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
assert "file_str_replace" not in tool_names
assert "put_complete_file_contents" not in tool_names
assert "run_programming_task" in tool_names
# Check that run_programming_task is enabled
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
assert "file_str_replace" not in tool_names
assert "put_complete_file_contents" not in tool_names
assert "run_programming_task" in tool_names
# Reset to default state for other tests
set_modification_tools(False)

View File

@ -1,3 +1,4 @@
import os
import pytest
from unittest.mock import patch, MagicMock
@ -7,6 +8,36 @@ from ra_aid.tools.programmer import (
run_programming_task,
)
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
from ra_aid.database.repositories.config_repository import get_config_repository
@pytest.fixture(autouse=True)
def mock_config_repository():
"""Mock the ConfigRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Create a dictionary to simulate config
config = {
"recursion_limit": 2,
"provider": "anthropic",
"model": "claude-3-5-sonnet-20241022",
"temperature": 0.01,
"aider_config": "/path/to/config.yml"
}
# Setup get_all method to return the config dict
mock_repo.get_all.return_value = config
# Setup get method to return config values
def get_config(key, default=None):
return config.get(key, default)
mock_repo.get.side_effect = get_config
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
@pytest.fixture(autouse=True)
def mock_related_files_repository():
@ -125,13 +156,9 @@ def test_parse_aider_flags(input_flags, expected, description):
assert result == expected, f"Failed test case: {description}"
def test_aider_config_flag(mocker, mock_related_files_repository):
def test_aider_config_flag(mocker, mock_config_repository, mock_related_files_repository):
"""Test that aider config flag is properly included in the command when specified."""
# Mock config in global memory but not related files (using repository now)
mock_memory = {
"config": {"aider_config": "/path/to/config.yml"},
}
mocker.patch("ra_aid.tools.programmer._global_memory", mock_memory)
# Config is mocked by mock_config_repository fixture
# Mock the run_interactive_command to capture the command that would be run
mock_run = mocker.patch(
@ -146,15 +173,14 @@ def test_aider_config_flag(mocker, mock_related_files_repository):
assert args[config_index + 1] == "/path/to/config.yml"
def test_path_normalization_and_deduplication(mocker, tmp_path, mock_related_files_repository):
def test_path_normalization_and_deduplication(mocker, tmp_path, mock_config_repository, mock_related_files_repository):
"""Test path normalization and deduplication in run_programming_task."""
# Create a temporary test file
test_file = tmp_path / "test.py"
test_file.write_text("")
new_file = tmp_path / "new.py"
# Mock dependencies - only need to mock config part of global memory now
mocker.patch("ra_aid.tools.programmer._global_memory", {"config": {}})
# Config is mocked by mock_config_repository fixture
mocker.patch(
"ra_aid.tools.programmer.get_aider_executable", return_value="/path/to/aider"
)

View File

@ -13,6 +13,7 @@ from ra_aid.tools.agent import (
from ra_aid.tools.memory import _global_memory
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
from ra_aid.database.repositories.work_log_repository import get_work_log_repository, WorkLogEntry
from ra_aid.database.repositories.config_repository import get_config_repository
@pytest.fixture
@ -43,6 +44,34 @@ def mock_related_files_repository():
yield mock_repo
@pytest.fixture(autouse=True)
def mock_config_repository():
"""Mock the ConfigRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Create a dictionary to simulate config
config = {
"recursion_limit": 2,
"provider": "anthropic",
"model": "claude-3-5-sonnet-20241022",
"temperature": 0.01
}
# Setup get_all method to return the config dict
mock_repo.get_all.return_value = config
# Setup get method to return config values
def get_config(key, default=None):
return config.get(key, default)
mock_repo.get.side_effect = get_config
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
@pytest.fixture(autouse=True)
def mock_work_log_repository():
"""Mock the WorkLogRepository to avoid database operations during tests"""

View File

@ -1,8 +1,8 @@
from unittest.mock import patch
from unittest.mock import patch, MagicMock
import pytest
from ra_aid.tools.memory import _global_memory
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
from ra_aid.tools.shell import run_shell_command
@ -25,9 +25,38 @@ def mock_run_interactive():
yield mock
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive):
@pytest.fixture(autouse=True)
def mock_config_repository():
"""Mock the ConfigRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Create a dictionary to simulate config
config = {
"cowboy_mode": False
}
# Setup get method to return config values
def get_config(key, default=None):
return config.get(key, default)
mock_repo.get.side_effect = get_config
# Setup set method to update config values
def set_config(key, value):
config[key] = value
mock_repo.set.side_effect = set_config
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
"""Test shell command execution in cowboy mode (no approval)"""
_global_memory["config"] = {"cowboy_mode": True}
# Set cowboy mode to True using the repository
mock_config_repository.set("cowboy_mode", True)
result = run_shell_command.invoke({"command": "echo test"})
@ -37,9 +66,10 @@ def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interacti
mock_prompt.ask.assert_not_called()
def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive):
def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
"""Test that cowboy mode displays a properly formatted cowboy message with correct spacing"""
_global_memory["config"] = {"cowboy_mode": True}
# Set cowboy mode to True using the repository
mock_config_repository.set("cowboy_mode", True)
with patch("ra_aid.tools.shell.get_cowboy_message") as mock_get_message:
mock_get_message.return_value = "🤠 Test cowboy message!"
@ -53,10 +83,11 @@ def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_intera
def test_shell_command_interactive_approved(
mock_console, mock_prompt, mock_run_interactive
mock_console, mock_prompt, mock_run_interactive, mock_config_repository
):
"""Test shell command execution with interactive approval"""
_global_memory["config"] = {"cowboy_mode": False}
# Set cowboy mode to False using the repository
mock_config_repository.set("cowboy_mode", False)
mock_prompt.ask.return_value = "y"
result = run_shell_command.invoke({"command": "echo test"})
@ -74,10 +105,11 @@ def test_shell_command_interactive_approved(
def test_shell_command_interactive_rejected(
mock_console, mock_prompt, mock_run_interactive
mock_console, mock_prompt, mock_run_interactive, mock_config_repository
):
"""Test shell command rejection in interactive mode"""
_global_memory["config"] = {"cowboy_mode": False}
# Set cowboy mode to False using the repository
mock_config_repository.set("cowboy_mode", False)
mock_prompt.ask.return_value = "n"
result = run_shell_command.invoke({"command": "echo test"})
@ -95,13 +127,14 @@ def test_shell_command_interactive_rejected(
mock_run_interactive.assert_not_called()
def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive):
def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
"""Test handling of shell command execution errors"""
_global_memory["config"] = {"cowboy_mode": True}
# Set cowboy mode to True using the repository
mock_config_repository.set("cowboy_mode", True)
mock_run_interactive.side_effect = Exception("Command failed")
result = run_shell_command.invoke({"command": "invalid command"})
assert result["success"] is False
assert result["return_code"] == 1
assert "Command failed" in result["output"]
assert "Command failed" in result["output"]