config repository

AI Christianson 2025-03-04 21:01:08 -05:00
parent 3e68dd3fa6
commit 5bd8c76a22
19 changed files with 695 additions and 290 deletions

View File

@@ -58,6 +58,10 @@ from ra_aid.database.repositories.related_files_repository import (
 from ra_aid.database.repositories.work_log_repository import (
     WorkLogRepositoryManager
 )
+from ra_aid.database.repositories.config_repository import (
+    ConfigRepositoryManager,
+    get_config_repository
+)
 from ra_aid.model_formatters import format_key_facts_dict
 from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
 from ra_aid.console.output import cpm
@@ -77,7 +81,7 @@ from ra_aid.prompts.chat_prompts import CHAT_PROMPT
 from ra_aid.prompts.web_research_prompts import WEB_RESEARCH_PROMPT_SECTION_CHAT
 from ra_aid.tool_configs import get_chat_tools, set_modification_tools
 from ra_aid.tools.human import ask_human
-from ra_aid.tools.memory import _global_memory, get_memory_value
+from ra_aid.tools.memory import get_memory_value

 logger = get_logger(__name__)
@@ -338,7 +342,7 @@ implementation_memory = MemorySaver()
 def is_informational_query() -> bool:
     """Determine if the current query is informational based on config settings."""
-    return _global_memory.get("config", {}).get("research_only", False)
+    return get_config_repository().get("research_only", False)

 def is_stage_requested(stage: str) -> bool:
@@ -404,13 +408,17 @@ def main():
     except Exception as e:
         logger.error(f"Database migration error: {str(e)}")

+    # Initialize empty config dictionary to be populated later
+    config = {}
+
     # Initialize repositories with database connection
     with KeyFactRepositoryManager(db) as key_fact_repo, \
         KeySnippetRepositoryManager(db) as key_snippet_repo, \
         HumanInputRepositoryManager(db) as human_input_repo, \
         ResearchNoteRepositoryManager(db) as research_note_repo, \
         RelatedFilesRepositoryManager() as related_files_repo, \
-        WorkLogRepositoryManager() as work_log_repo:
+        WorkLogRepositoryManager() as work_log_repo, \
+        ConfigRepositoryManager(config) as config_repo:
         # This initializes all repositories and makes them available via their respective get methods
         logger.debug("Initialized KeyFactRepository")
         logger.debug("Initialized KeySnippetRepository")
@@ -418,6 +426,7 @@ def main():
         logger.debug("Initialized ResearchNoteRepository")
         logger.debug("Initialized RelatedFilesRepository")
         logger.debug("Initialized WorkLogRepository")
+        logger.debug("Initialized ConfigRepository")

         # Check dependencies before proceeding
         check_dependencies()
@@ -520,13 +529,13 @@ def main():
                 "limit_tokens": args.disable_limit_tokens,
             }

-            # Store config in global memory
-            _global_memory["config"] = config
-            _global_memory["config"]["provider"] = args.provider
-            _global_memory["config"]["model"] = args.model
-            _global_memory["config"]["expert_provider"] = args.expert_provider
-            _global_memory["config"]["expert_model"] = args.expert_model
-            _global_memory["config"]["temperature"] = args.temperature
+            # Store config in repository
+            config_repo.update(config)
+            config_repo.set("provider", args.provider)
+            config_repo.set("model", args.model)
+            config_repo.set("expert_provider", args.expert_provider)
+            config_repo.set("expert_model", args.expert_model)
+            config_repo.set("temperature", args.temperature)

             # Set modification tools based on use_aider flag
             set_modification_tools(args.use_aider)
@@ -594,33 +603,27 @@ def main():
             "test_cmd_timeout": args.test_cmd_timeout,
         }

-        # Store config in global memory for access by is_informational_query
-        _global_memory["config"] = config
+        # Store config in repository
+        config_repo.update(config)

         # Store base provider/model configuration
-        _global_memory["config"]["provider"] = args.provider
-        _global_memory["config"]["model"] = args.model
+        config_repo.set("provider", args.provider)
+        config_repo.set("model", args.model)

         # Store expert provider/model (no fallback)
-        _global_memory["config"]["expert_provider"] = args.expert_provider
-        _global_memory["config"]["expert_model"] = args.expert_model
+        config_repo.set("expert_provider", args.expert_provider)
+        config_repo.set("expert_model", args.expert_model)

         # Store planner config with fallback to base values
-        _global_memory["config"]["planner_provider"] = (
-            args.planner_provider or args.provider
-        )
-        _global_memory["config"]["planner_model"] = args.planner_model or args.model
+        config_repo.set("planner_provider", args.planner_provider or args.provider)
+        config_repo.set("planner_model", args.planner_model or args.model)

         # Store research config with fallback to base values
-        _global_memory["config"]["research_provider"] = (
-            args.research_provider or args.provider
-        )
-        _global_memory["config"]["research_model"] = (
-            args.research_model or args.model
-        )
+        config_repo.set("research_provider", args.research_provider or args.provider)
+        config_repo.set("research_model", args.research_model or args.model)

-        # Store temperature in global config
-        _global_memory["config"]["temperature"] = args.temperature
+        # Store temperature in config
+        config_repo.set("temperature", args.temperature)

         # Set modification tools based on use_aider flag
         set_modification_tools(args.use_aider)

View File

@@ -94,11 +94,11 @@ from ra_aid.model_formatters import format_key_facts_dict
 from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
 from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
 from ra_aid.tools.memory import (
-    _global_memory,
     get_memory_value,
     get_related_files,
     log_work_event,
 )
+from ra_aid.database.repositories.config_repository import get_config_repository

 console = Console()
@@ -302,7 +302,7 @@ def create_agent(
         config['limit_tokens'] = False.
     """
     try:
-        config = _global_memory.get("config", {})
+        config = get_config_repository().get_all()
         max_input_tokens = (
             get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
         )
@@ -319,7 +319,7 @@ def create_agent(
     except Exception as e:
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
-        config = _global_memory.get("config", {})
+        config = get_config_repository().get_all()
         max_input_tokens = get_model_token_limit(config, agent_type)
         agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
         return create_react_agent(model, tools, **agent_kwargs)
@@ -443,7 +443,7 @@ def run_research_agent(
         new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "",
     )

-    config = _global_memory.get("config", {}) if not config else config
+    config = get_config_repository().get_all() if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
     run_config = {
         "configurable": {"thread_id": thread_id},
@@ -575,7 +575,7 @@ def run_web_research_agent(
         related_files=related_files,
     )

-    config = _global_memory.get("config", {}) if not config else config
+    config = get_config_repository().get_all() if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
     run_config = {
@@ -709,7 +709,7 @@ def run_planning_agent(
         ),
     )

-    config = _global_memory.get("config", {}) if not config else config
+    config = get_config_repository().get_all() if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
     run_config = {
         "configurable": {"thread_id": thread_id},
@@ -824,7 +824,7 @@ def run_task_implementation_agent(
         expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
         human_section=(
             HUMAN_PROMPT_SECTION_IMPLEMENTATION
-            if _global_memory.get("config", {}).get("hil", False)
+            if get_config_repository().get("hil", False)
             else ""
         ),
         web_research_section=(
@@ -834,7 +834,7 @@ def run_task_implementation_agent(
         ),
     )

-    config = _global_memory.get("config", {}) if not config else config
+    config = get_config_repository().get_all() if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
     run_config = {
         "configurable": {"thread_id": thread_id},

View File

@@ -19,9 +19,10 @@ logger = logging.getLogger(__name__)
 from ra_aid.agent_utils import create_agent, run_agent_with_retry
 from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
 from ra_aid.database.repositories.human_input_repository import get_human_input_repository
+from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.llm import initialize_llm
 from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
-from ra_aid.tools.memory import log_work_event, _global_memory
+from ra_aid.tools.memory import log_work_event

 console = Console()
@@ -149,7 +150,7 @@ def run_key_facts_gc_agent() -> None:
     formatted_facts = "\n".join([f"Fact #{k}: {v}" for k, v in facts_dict.items()])

     # Retrieve configuration
-    llm_config = _global_memory.get("config", {})
+    llm_config = get_config_repository().get_all()

     # Initialize the LLM model
     model = initialize_llm(

View File

@@ -16,9 +16,10 @@ from rich.panel import Panel
 from ra_aid.agent_utils import create_agent, run_agent_with_retry
 from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
 from ra_aid.database.repositories.human_input_repository import get_human_input_repository
+from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.llm import initialize_llm
 from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
-from ra_aid.tools.memory import log_work_event, _global_memory
+from ra_aid.tools.memory import log_work_event

 console = Console()
@@ -153,7 +154,7 @@ def run_key_snippets_gc_agent() -> None:
     ])

     # Retrieve configuration
-    llm_config = _global_memory.get("config", {})
+    llm_config = get_config_repository().get_all()

     # Initialize the LLM model
     model = initialize_llm(

View File

@@ -19,9 +19,10 @@ logger = logging.getLogger(__name__)
 from ra_aid.agent_utils import create_agent, run_agent_with_retry
 from ra_aid.database.repositories.research_note_repository import get_research_note_repository
 from ra_aid.database.repositories.human_input_repository import get_human_input_repository
+from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.llm import initialize_llm
 from ra_aid.model_formatters.research_notes_formatter import format_research_note
-from ra_aid.tools.memory import log_work_event, _global_memory
+from ra_aid.tools.memory import log_work_event

 console = Console()
@@ -154,7 +155,7 @@ def run_research_notes_gc_agent(threshold: int = 30) -> None:
     formatted_notes = "\n".join([f"Note #{k}: {v}" for k, v in notes_dict.items()])

     # Retrieve configuration
-    llm_config = _global_memory.get("config", {})
+    llm_config = get_config_repository().get_all()

     # Initialize the LLM model
     model = initialize_llm(

View File

@@ -0,0 +1,165 @@
"""Repository for managing configuration values."""

import contextvars
from typing import Any, Dict, Optional

# Create contextvar to hold the ConfigRepository instance
config_repo_var = contextvars.ContextVar("config_repo", default=None)


class ConfigRepository:
    """
    Repository for managing configuration values in memory.

    This class provides methods to get, set, update, and retrieve all configuration values.
    It does not require database models and operates entirely in memory.
    """

    def __init__(self, initial_config: Optional[Dict[str, Any]] = None):
        """
        Initialize the ConfigRepository.

        Args:
            initial_config: Optional dictionary of initial configuration values
        """
        self._config: Dict[str, Any] = {}

        # Initialize with default values from config.py
        from ra_aid.config import (
            DEFAULT_RECURSION_LIMIT,
            DEFAULT_MAX_TEST_CMD_RETRIES,
            DEFAULT_MAX_TOOL_FAILURES,
            FALLBACK_TOOL_MODEL_LIMIT,
            RETRY_FALLBACK_COUNT,
            DEFAULT_TEST_CMD_TIMEOUT,
            VALID_PROVIDERS,
        )
        self._config = {
            "recursion_limit": DEFAULT_RECURSION_LIMIT,
            "max_test_cmd_retries": DEFAULT_MAX_TEST_CMD_RETRIES,
            "max_tool_failures": DEFAULT_MAX_TOOL_FAILURES,
            "fallback_tool_model_limit": FALLBACK_TOOL_MODEL_LIMIT,
            "retry_fallback_count": RETRY_FALLBACK_COUNT,
            "test_cmd_timeout": DEFAULT_TEST_CMD_TIMEOUT,
            "valid_providers": VALID_PROVIDERS,
        }

        # Update with any provided initial configuration
        if initial_config:
            self._config.update(initial_config)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a configuration value by key.

        Args:
            key: Configuration key to retrieve
            default: Default value to return if key is not found

        Returns:
            The configuration value or default if not found
        """
        return self._config.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """
        Set a configuration value by key.

        Args:
            key: Configuration key to set
            value: Value to set for the key
        """
        self._config[key] = value

    def update(self, config_dict: Dict[str, Any]) -> None:
        """
        Update multiple configuration values at once.

        Args:
            config_dict: Dictionary of configuration key-value pairs to update
        """
        self._config.update(config_dict)

    def get_all(self) -> Dict[str, Any]:
        """
        Get all configuration values.

        Returns:
            Dictionary containing all configuration values
        """
        return self._config.copy()


class ConfigRepositoryManager:
    """
    Context manager for ConfigRepository.

    This class provides a context manager interface for ConfigRepository,
    using the contextvars approach for thread safety.

    Example:
        with ConfigRepositoryManager() as repo:
            # Use the repository
            value = repo.get("key")
            repo.set("key", new_value)
    """

    def __init__(self, initial_config: Optional[Dict[str, Any]] = None):
        """
        Initialize the ConfigRepositoryManager.

        Args:
            initial_config: Optional dictionary of initial configuration values
        """
        self.initial_config = initial_config

    def __enter__(self) -> 'ConfigRepository':
        """
        Initialize the ConfigRepository and return it.

        Returns:
            ConfigRepository: The initialized repository
        """
        repo = ConfigRepository(self.initial_config)
        config_repo_var.set(repo)
        return repo

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[Exception],
        exc_tb: Optional[object],
    ) -> None:
        """
        Reset the repository when exiting the context.

        Args:
            exc_type: The exception type if an exception was raised
            exc_val: The exception value if an exception was raised
            exc_tb: The traceback if an exception was raised
        """
        # Reset the contextvar to None
        config_repo_var.set(None)

        # Don't suppress exceptions
        return False


def get_config_repository() -> ConfigRepository:
    """
    Get the current ConfigRepository instance.

    Returns:
        ConfigRepository: The current repository instance

    Raises:
        RuntimeError: If no repository is set in the current context
    """
    repo = config_repo_var.get()
    if repo is None:
        raise RuntimeError(
            "ConfigRepository not initialized in current context. "
            "Make sure to use ConfigRepositoryManager."
        )
    return repo
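
For orientation, a minimal usage sketch of the repository and accessor defined above (illustrative only, not part of the commit; the values shown are made up):

    from ra_aid.database.repositories.config_repository import (
        ConfigRepositoryManager,
        get_config_repository,
    )

    # Entering the context manager binds a fresh ConfigRepository to the contextvar.
    with ConfigRepositoryManager({"provider": "anthropic"}) as config_repo:
        config_repo.set("cowboy_mode", True)

        # Code running anywhere inside the block can reach the same instance
        # without having it passed in explicitly.
        assert get_config_repository().get("provider") == "anthropic"
        assert get_config_repository().get("cowboy_mode", False) is True

    # After the block exits, the contextvar is reset, so calling
    # get_config_repository() here raises RuntimeError.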

View File

@@ -27,6 +27,7 @@ from ra_aid.tools.agent import (
     request_web_research,
 )
 from ra_aid.tools.memory import plan_implementation_completed
+from ra_aid.database.repositories.config_repository import get_config_repository

 def set_modification_tools(use_aider=False):
@@ -98,13 +99,11 @@ def get_all_tools() -> list[BaseTool]:
 # Define constant tool groups
-# Get config from global memory for use_aider value
+# Get config from repository for use_aider value
 _config = {}
 try:
-    from ra_aid.tools.memory import _global_memory
-    _config = _global_memory.get("config", {})
-except ImportError:
+    _config = get_config_repository().get_all()
+except (ImportError, RuntimeError):
     pass

 READ_ONLY_TOOLS = get_read_only_tools(use_aider=_config.get("use_aider", False))
@@ -139,10 +138,8 @@ def get_research_tools(
     # Get config for use_aider value
     use_aider = False
     try:
-        from ra_aid.tools.memory import _global_memory
-        use_aider = _global_memory.get("config", {}).get("use_aider", False)
-    except ImportError:
+        use_aider = get_config_repository().get("use_aider", False)
+    except (ImportError, RuntimeError):
         pass

     # Start with read-only tools
@@ -180,10 +177,8 @@ def get_planning_tools(
     # Get config for use_aider value
     use_aider = False
     try:
-        from ra_aid.tools.memory import _global_memory
-        use_aider = _global_memory.get("config", {}).get("use_aider", False)
-    except ImportError:
+        use_aider = get_config_repository().get("use_aider", False)
+    except (ImportError, RuntimeError):
         pass

     # Start with read-only tools
@@ -219,10 +214,8 @@ def get_implementation_tools(
     # Get config for use_aider value
     use_aider = False
     try:
-        from ra_aid.tools.memory import _global_memory
-        use_aider = _global_memory.get("config", {}).get("use_aider", False)
-    except ImportError:
+        use_aider = get_config_repository().get("use_aider", False)
+    except (ImportError, RuntimeError):
         pass

     # Start with read-only tools
@@ -285,4 +278,4 @@ def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = Fal
     if web_research_enabled:
         tools.append(request_web_research)
     return tools

View File

@@ -18,13 +18,13 @@ from ra_aid.console.formatting import print_error
 from ra_aid.database.repositories.human_input_repository import HumanInputRepository
 from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
 from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
+from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.database.repositories.related_files_repository import get_related_files_repository
 from ra_aid.database.repositories.research_note_repository import get_research_note_repository
 from ra_aid.exceptions import AgentInterrupt
 from ra_aid.model_formatters import format_key_facts_dict
 from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
 from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
-from ra_aid.tools.memory import _global_memory

 from ..console import print_task_header
 from ..llm import initialize_llm
@@ -52,7 +52,7 @@ def request_research(query: str) -> ResearchResult:
         query: The research question or project description
     """
     # Initialize model from config
-    config = _global_memory.get("config", {})
+    config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
         config.get("model", "claude-3-7-sonnet-20250219"),
@@ -165,7 +165,7 @@ def request_web_research(query: str) -> ResearchResult:
         query: The research question or project description
     """
     # Initialize model from config
-    config = _global_memory.get("config", {})
+    config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
         config.get("model", "claude-3-7-sonnet-20250219"),
@@ -246,7 +246,7 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
         query: The research question or project description
     """
     # Initialize model from config
-    config = _global_memory.get("config", {})
+    config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
         config.get("model", "claude-3-7-sonnet-20250219"),
@@ -335,7 +335,7 @@ def request_task_implementation(task_spec: str) -> str:
         task_spec: REQUIRED The full task specification (markdown format, typically one part of the overall plan)
     """
     # Initialize model from config
-    config = _global_memory.get("config", {})
+    config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
         config.get("model", "claude-3-5-sonnet-20241022"),
@@ -474,7 +474,7 @@ def request_implementation(task_spec: str) -> str:
         task_spec: The task specification to plan implementation for
     """
     # Initialize model from config
-    config = _global_memory.get("config", {})
+    config = get_config_repository().get_all()
     model = initialize_llm(
         config.get("provider", "anthropic"),
         config.get("model", "claude-3-5-sonnet-20241022"),

View File

@@ -13,11 +13,12 @@ from ..database.repositories.key_fact_repository import get_key_fact_repository
 from ..database.repositories.key_snippet_repository import get_key_snippet_repository
 from ..database.repositories.related_files_repository import get_related_files_repository
 from ..database.repositories.research_note_repository import get_research_note_repository
+from ..database.repositories.config_repository import get_config_repository
 from ..llm import initialize_expert_llm
 from ..model_formatters import format_key_facts_dict
 from ..model_formatters.key_snippets_formatter import format_key_snippets_dict
 from ..model_formatters.research_notes_formatter import format_research_notes_dict
-from .memory import _global_memory, get_memory_value
+from .memory import get_memory_value

 console = Console()
 _model = None
@@ -27,9 +28,9 @@ def get_model():
     global _model
     try:
         if _model is None:
-            config = _global_memory["config"]
-            provider = config.get("expert_provider") or config.get("provider")
-            model = config.get("expert_model") or config.get("model")
+            config_repo = get_config_repository()
+            provider = config_repo.get("expert_provider") or config_repo.get("provider")
+            model = config_repo.get("expert_model") or config_repo.get("model")
             _model = initialize_expert_llm(provider, model)
     except Exception as e:
         _model = None

View File

@@ -58,10 +58,10 @@ def ask_human(question: str) -> str:
     # Record human response in database
     try:
         from ra_aid.database.repositories.human_input_repository import get_human_input_repository
-        from ra_aid.tools.memory import _global_memory
+        from ra_aid.database.repositories.config_repository import get_config_repository

         # Determine the source based on context
-        config = _global_memory.get("config", {})
+        config = get_config_repository().get_all()
         # If chat_mode is enabled, use 'chat', otherwise determine if hil mode is active
         if config.get("chat_mode", False):
             source = "chat"

View File

@@ -13,7 +13,8 @@ from ra_aid.logging_config import get_logger
 from ra_aid.models_params import DEFAULT_BASE_LATENCY, models_params
 from ra_aid.proc.interactive import run_interactive_command
 from ra_aid.text.processing import truncate_output
-from ra_aid.tools.memory import _global_memory, log_work_event
+from ra_aid.tools.memory import log_work_event
+from ra_aid.database.repositories.config_repository import get_config_repository
 from ra_aid.database.repositories.related_files_repository import get_related_files_repository

 console = Console()
@@ -107,8 +108,9 @@ def run_programming_task(
     )

     # Add config file if specified
-    if "config" in _global_memory and _global_memory["config"].get("aider_config"):
-        command.extend(["--config", _global_memory["config"]["aider_config"]])
+    config = get_config_repository().get_all()
+    if config.get("aider_config"):
+        command.extend(["--config", config["aider_config"]])

     # if environment variable AIDER_FLAGS exists then parse
     if "AIDER_FLAGS" in os.environ:
@@ -147,8 +149,9 @@ def run_programming_task(
     # Run the command interactively
     print()
     # Get provider/model specific latency coefficient
-    provider = _global_memory.get("config", {}).get("provider", "")
-    model = _global_memory.get("config", {}).get("model", "")
+    config = get_config_repository().get_all()
+    provider = config.get("provider", "")
+    model = config.get("model", "")
     latency = (
         models_params.get(provider, {})
         .get(model, {})

View File

@@ -9,6 +9,7 @@ from ra_aid.console.cowboy_messages import get_cowboy_message
 from ra_aid.proc.interactive import run_interactive_command
 from ra_aid.text.processing import truncate_output
 from ra_aid.tools.memory import _global_memory, log_work_event
+from ra_aid.database.repositories.config_repository import get_config_repository

 console = Console()
@@ -46,7 +47,7 @@ def run_shell_command(
     4. Add flags e.g. git --no-pager in order to reduce interaction required by the human.
     """
     # Check if we need approval
-    cowboy_mode = _global_memory.get("config", {}).get("cowboy_mode", False)
+    cowboy_mode = get_config_repository().get("cowboy_mode", False)

     if cowboy_mode:
         console.print("")
@@ -74,7 +75,7 @@ def run_shell_command(
                 "success": False,
             }
         elif response == "c":
-            _global_memory["config"]["cowboy_mode"] = True
+            get_config_repository().set("cowboy_mode", True)
             console.print("")
             console.print(" " + get_cowboy_message())
             console.print("")
@@ -96,4 +97,4 @@ def run_shell_command(
     except Exception as e:
         print()
         console.print(Panel(str(e), title="❌ Error", border_style="red"))
         return {"output": str(e), "return_code": 1, "success": False}

View File

@@ -7,10 +7,25 @@ ensuring consistent test environments and proper isolation.
 import os
 from pathlib import Path
+from unittest.mock import MagicMock

 import pytest

+@pytest.fixture()
+def mock_config_repository():
+    """Mock the config repository."""
+    from ra_aid.database.repositories.config_repository import get_config_repository
+    repo = MagicMock()
+    # Default config values
+    config_values = {"recursion_limit": 2}
+    repo.get_all.return_value = config_values
+    repo.get.side_effect = lambda key, default=None: config_values.get(key, default)
+    get_config_repository.return_value = repo
+    yield repo

 @pytest.fixture(autouse=True)
 def isolated_db_environment(tmp_path, monkeypatch):
     """

View File

@@ -1,7 +1,7 @@
 """Unit tests for agent_utils.py."""
 from typing import Any, Dict, Literal
-from unittest.mock import Mock, patch
+from unittest.mock import Mock, patch, MagicMock

 import litellm
 import pytest
@@ -19,6 +19,7 @@ from ra_aid.agent_utils import (
     state_modifier,
 )
 from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
+from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var

 @pytest.fixture
@@ -29,40 +30,70 @@ def mock_model():
 @pytest.fixture
-def mock_memory():
-    """Fixture providing a mock global memory store."""
-    with patch("ra_aid.agent_utils._global_memory") as mock_mem:
-        mock_mem.get.return_value = {}
-        yield mock_mem
+def mock_config_repository():
+    """Mock the ConfigRepository to avoid database operations during tests"""
+    with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
+        # Setup a mock repository
+        mock_repo = MagicMock()
+        # Create a dictionary to simulate config
+        config = {}
+        # Setup get method to return config values
+        def get_config(key, default=None):
+            return config.get(key, default)
+        mock_repo.get.side_effect = get_config
+        # Setup get_all method to return all config values
+        mock_repo.get_all.return_value = config
+        # Setup set method to update config values
+        def set_config(key, value):
+            config[key] = value
+        mock_repo.set.side_effect = set_config
+        # Setup update method to update multiple config values
+        def update_config(update_dict):
+            config.update(update_dict)
+        mock_repo.update.side_effect = update_config
+        # Make the mock context var return our mock repo
+        mock_repo_var.get.return_value = mock_repo
+        yield mock_repo

-def test_get_model_token_limit_anthropic(mock_memory):
+def test_get_model_token_limit_anthropic(mock_config_repository):
     """Test get_model_token_limit with Anthropic model."""
     config = {"provider": "anthropic", "model": "claude2"}
+    mock_config_repository.update(config)

     token_limit = get_model_token_limit(config, "default")
     assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]

-def test_get_model_token_limit_openai(mock_memory):
+def test_get_model_token_limit_openai(mock_config_repository):
     """Test get_model_token_limit with OpenAI model."""
     config = {"provider": "openai", "model": "gpt-4"}
+    mock_config_repository.update(config)

     token_limit = get_model_token_limit(config, "default")
     assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]

-def test_get_model_token_limit_unknown(mock_memory):
+def test_get_model_token_limit_unknown(mock_config_repository):
     """Test get_model_token_limit with unknown provider/model."""
     config = {"provider": "unknown", "model": "unknown-model"}
+    mock_config_repository.update(config)

     token_limit = get_model_token_limit(config, "default")
     assert token_limit is None

-def test_get_model_token_limit_missing_config(mock_memory):
+def test_get_model_token_limit_missing_config(mock_config_repository):
     """Test get_model_token_limit with missing configuration."""
     config = {}
+    mock_config_repository.update(config)

     token_limit = get_model_token_limit(config, "default")
     assert token_limit is None
@@ -108,9 +139,9 @@ def test_get_model_token_limit_unexpected_error():
     assert token_limit is None

-def test_create_agent_anthropic(mock_model, mock_memory):
+def test_create_agent_anthropic(mock_model, mock_config_repository):
     """Test create_agent with Anthropic Claude model."""
-    mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"}
+    mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})

     with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
         mock_react.return_value = "react_agent"
@@ -125,9 +156,9 @@ def test_create_agent_anthropic(mock_model, mock_memory):
     )

-def test_create_agent_openai(mock_model, mock_memory):
+def test_create_agent_openai(mock_model, mock_config_repository):
     """Test create_agent with OpenAI model."""
-    mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
+    mock_config_repository.update({"provider": "openai", "model": "gpt-4"})

     with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
         mock_ciayn.return_value = "ciayn_agent"
@@ -142,9 +173,9 @@ def test_create_agent_openai(mock_model, mock_memory):
     )

-def test_create_agent_no_token_limit(mock_model, mock_memory):
+def test_create_agent_no_token_limit(mock_model, mock_config_repository):
     """Test create_agent when no token limit is found."""
-    mock_memory.get.return_value = {"provider": "unknown", "model": "unknown-model"}
+    mock_config_repository.update({"provider": "unknown", "model": "unknown-model"})

     with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
         mock_ciayn.return_value = "ciayn_agent"
@@ -159,9 +190,9 @@ def test_create_agent_no_token_limit(mock_model, mock_memory):
     )

-def test_create_agent_missing_config(mock_model, mock_memory):
+def test_create_agent_missing_config(mock_model, mock_config_repository):
     """Test create_agent with missing configuration."""
-    mock_memory.get.return_value = {"provider": "openai"}
+    mock_config_repository.update({"provider": "openai"})

     with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
         mock_ciayn.return_value = "ciayn_agent"
@@ -205,9 +236,9 @@ def test_state_modifier(mock_messages):
     assert result[-1] == mock_messages[-1]

-def test_create_agent_with_checkpointer(mock_model, mock_memory):
+def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
     """Test create_agent with checkpointer argument."""
-    mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
+    mock_config_repository.update({"provider": "openai", "model": "gpt-4"})
     mock_checkpointer = Mock()

     with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
@@ -223,13 +254,13 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory):
     )

-def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
+def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
     """Test create_agent sets up token limiting for Claude models when enabled."""
-    mock_memory.get.return_value = {
+    mock_config_repository.update({
         "provider": "anthropic",
         "model": "claude-2",
         "limit_tokens": True,
-    }
+    })

     with (
         patch("ra_aid.agent_utils.create_react_agent") as mock_react,
@@ -246,13 +277,13 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
         assert callable(args[1]["state_modifier"])

-def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory):
+def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
     """Test create_agent doesn't set up token limiting for Claude models when disabled."""
-    mock_memory.get.return_value = {
+    mock_config_repository.update({
         "provider": "anthropic",
         "model": "claude-2",
         "limit_tokens": False,
-    }
+    })

     with (
         patch("ra_aid.agent_utils.create_react_agent") as mock_react,
@@ -267,7 +298,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
     mock_react.assert_called_once_with(mock_model, [], version="v2")

-def test_get_model_token_limit_research(mock_memory):
+def test_get_model_token_limit_research(mock_config_repository):
     """Test get_model_token_limit with research provider and model."""
     config = {
         "provider": "openai",
@@ -275,13 +306,13 @@ def test_get_model_token_limit_research(mock_memory):
         "research_provider": "anthropic",
         "research_model": "claude-2",
     }
+    mock_config_repository.update(config)

     with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.return_value = {"max_input_tokens": 150000}
         token_limit = get_model_token_limit(config, "research")
         assert token_limit == 150000

-def test_get_model_token_limit_planner(mock_memory):
+def test_get_model_token_limit_planner(mock_config_repository):
     """Test get_model_token_limit with planner provider and model."""
     config = {
         "provider": "openai",
@@ -289,6 +322,8 @@ def test_get_model_token_limit_planner(mock_memory):
         "planner_provider": "deepseek",
         "planner_model": "dsm-1",
     }
+    mock_config_repository.update(config)

     with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.return_value = {"max_input_tokens": 120000}
         token_limit = get_model_token_limit(config, "planner")
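
As a side note, an illustrative sketch (not part of the diff) of why patching config_repo_var works in these fixtures: get_config_repository() simply reads the contextvar, so replacing that contextvar with a mock whose get() returns a fake repository bypasses the normal ConfigRepositoryManager setup entirely:

    from unittest.mock import MagicMock, patch

    from ra_aid.database.repositories.config_repository import get_config_repository

    mock_repo = MagicMock()
    with patch("ra_aid.database.repositories.config_repository.config_repo_var") as mock_var:
        mock_var.get.return_value = mock_repo
        # config_repo_var.get() now returns the mock, so the RuntimeError branch is skipped.
        assert get_config_repository() is mock_repo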

View File

@@ -1,41 +1,35 @@
 """Tests for the is_informational_query and is_stage_requested functions."""
+import pytest

 from ra_aid.__main__ import is_informational_query, is_stage_requested
-from ra_aid.tools.memory import _global_memory
+from ra_aid.database.repositories.config_repository import ConfigRepositoryManager

-def test_is_informational_query():
+@pytest.fixture
+def config_repo():
+    """Fixture for config repository."""
+    with ConfigRepositoryManager() as repo:
+        yield repo

+def test_is_informational_query(config_repo):
     """Test that is_informational_query only depends on research_only config setting."""
-    # Clear global memory to ensure clean state
-    _global_memory.clear()

     # When research_only is True, should return True
-    _global_memory["config"] = {"research_only": True}
+    config_repo.set("research_only", True)
     assert is_informational_query() is True

     # When research_only is False, should return False
-    _global_memory["config"] = {"research_only": False}
+    config_repo.set("research_only", False)
     assert is_informational_query() is False

     # When config is empty, should return False (default)
-    _global_memory.clear()
-    _global_memory["config"] = {}
-    assert is_informational_query() is False
-
-    # When global memory is empty, should return False (default)
-    _global_memory.clear()
+    config_repo.update({})
     assert is_informational_query() is False

 def test_is_stage_requested():
     """Test that is_stage_requested always returns False now."""
-    # Clear global memory to ensure clean state
-    _global_memory.clear()

     # Should always return False regardless of input
     assert is_stage_requested("implementation") is False
     assert is_stage_requested("anything_else") is False
-
-    # Even if we set implementation_requested in global memory
-    _global_memory["implementation_requested"] = True
-    assert is_stage_requested("implementation") is False

View File

@@ -7,14 +7,50 @@ from ra_aid.__main__ import parse_arguments
 from ra_aid.config import DEFAULT_RECURSION_LIMIT
 from ra_aid.tools.memory import _global_memory
 from ra_aid.database.repositories.work_log_repository import WorkLogEntry
+from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository

+@pytest.fixture(autouse=True)
+def mock_config_repository():
+    """Mock the ConfigRepository to avoid database operations during tests"""
+    with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
+        # Setup a mock repository
+        mock_repo = MagicMock()
+        # Create a dictionary to simulate config
+        config = {}
+        # Setup get method to return config values
+        def get_config(key, default=None):
+            return config.get(key, default)
+        mock_repo.get.side_effect = get_config
+        # Setup set method to update config values
+        def set_config(key, value):
+            config[key] = value
+        mock_repo.set.side_effect = set_config
+        # Setup update method to update multiple config values
+        def update_config(config_dict):
+            config.update(config_dict)
+        mock_repo.update.side_effect = update_config
+        # Setup get_all method to return the config dict
+        def get_all_config():
+            return config.copy()
+        mock_repo.get_all.side_effect = get_all_config
+        # Make the mock context var return our mock repo
+        mock_repo_var.get.return_value = mock_repo
+        yield mock_repo

 @pytest.fixture
 def mock_dependencies(monkeypatch):
     """Mock all dependencies needed for main()."""
-    # Initialize global memory with necessary keys to prevent KeyError
+    # Initialize global memory
     _global_memory.clear()
-    _global_memory["config"] = {}

     # Mock dependencies that interact with external systems
     monkeypatch.setattr("ra_aid.__main__.check_dependencies", lambda: None)
@@ -26,10 +62,9 @@ def mock_dependencies(monkeypatch):
     # Mock LLM initialization
     def mock_config_update(*args, **kwargs):
-        config = _global_memory.get("config", {})
+        config_repo = get_config_repository()
         if kwargs.get("temperature"):
-            config["temperature"] = kwargs["temperature"]
-        _global_memory["config"] = config
+            config_repo.set("temperature", kwargs["temperature"])
         return None

     monkeypatch.setattr("ra_aid.__main__.initialize_llm", mock_config_update)
@@ -107,26 +142,52 @@ def mock_work_log_repository():
     yield mock_repo

-def test_recursion_limit_in_global_config(mock_dependencies):
+def test_recursion_limit_in_global_config(mock_dependencies, mock_config_repository):
     """Test that recursion limit is correctly set in global config."""
     import sys
     from unittest.mock import patch
     from ra_aid.__main__ import main

-    _global_memory.clear()
-
-    with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]):
-        main()
-        assert _global_memory["config"]["recursion_limit"] == DEFAULT_RECURSION_LIMIT
-
-    _global_memory.clear()
-
-    with patch.object(
-        sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"]
-    ):
-        main()
-        assert _global_memory["config"]["recursion_limit"] == 50
+    # Clear the mock repository before each test
+    mock_config_repository.update.reset_mock()
+
+    # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
+    with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
+        # Test default recursion limit
+        with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]):
+            main()
+            # Check that the recursion_limit value was included in the update call
+            mock_config_repository.update.assert_called()
+            # Get the call arguments
+            call_args = mock_config_repository.update.call_args_list
+            # Find the call that includes recursion_limit
+            recursion_limit_found = False
+            for args, _ in call_args:
+                config_dict = args[0]
+                if "recursion_limit" in config_dict and config_dict["recursion_limit"] == DEFAULT_RECURSION_LIMIT:
+                    recursion_limit_found = True
+                    break
+            assert recursion_limit_found, f"recursion_limit not found in update calls: {call_args}"
+
+        # Reset mock to clear call history
+        mock_config_repository.update.reset_mock()
+
+        # Test custom recursion limit
+        with patch.object(sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"]):
+            main()
+            # Check that the recursion_limit value was included in the update call
+            mock_config_repository.update.assert_called()
+            # Get the call arguments
+            call_args = mock_config_repository.update.call_args_list
+            # Find the call that includes recursion_limit with value 50
+            recursion_limit_found = False
+            for args, _ in call_args:
+                config_dict = args[0]
+                if "recursion_limit" in config_dict and config_dict["recursion_limit"] == 50:
+                    recursion_limit_found = True
+                    break
+            assert recursion_limit_found, f"recursion_limit=50 not found in update calls: {call_args}"

 def test_negative_recursion_limit():
@@ -141,70 +202,83 @@ def test_zero_recursion_limit():
         parse_arguments(["-m", "test message", "--recursion-limit", "0"])

-def test_config_settings(mock_dependencies):
+def test_config_settings(mock_dependencies, mock_config_repository):
     """Test that various settings are correctly applied in global config."""
     import sys
     from unittest.mock import patch
     from ra_aid.__main__ import main

-    _global_memory.clear()
-
-    with patch.object(
-        sys,
-        "argv",
-        [
-            "ra-aid",
-            "-m",
-            "test message",
-            "--cowboy-mode",
-            "--research-only",
-            "--provider",
-            "anthropic",
-            "--model",
-            "claude-3-7-sonnet-20250219",
-            "--expert-provider",
-            "openai",
-            "--expert-model",
-            "gpt-4",
-            "--temperature",
-            "0.7",
-            "--disable-limit-tokens",
-        ],
-    ):
-        main()
-        config = _global_memory["config"]
-        assert config["cowboy_mode"] is True
-        assert config["research_only"] is True
-        assert config["provider"] == "anthropic"
-        assert config["model"] == "claude-3-7-sonnet-20250219"
-        assert config["expert_provider"] == "openai"
-        assert config["expert_model"] == "gpt-4"
-        assert config["limit_tokens"] is False
+    # Clear the mock repository before each test
+    mock_config_repository.update.reset_mock()
+
+    # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
+    with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
+        with patch.object(
+            sys,
+            "argv",
+            [
+                "ra-aid",
+                "-m",
+                "test message",
+                "--cowboy-mode",
+                "--research-only",
+                "--provider",
+                "anthropic",
+                "--model",
+                "claude-3-7-sonnet-20250219",
+                "--expert-provider",
+                "openai",
+                "--expert-model",
+                "gpt-4",
+                "--temperature",
+                "0.7",
+                "--disable-limit-tokens",
+            ],
+        ):
+            main()
+            # Verify config values are set via the update method
+            mock_config_repository.update.assert_called()
+            # Get the call arguments
+            call_args = mock_config_repository.update.call_args_list
+
+            # Check for config values in the update calls
+            for args, _ in call_args:
+                config_dict = args[0]
+                if "cowboy_mode" in config_dict:
+                    assert config_dict["cowboy_mode"] is True
+                if "research_only" in config_dict:
+                    assert config_dict["research_only"] is True
+                if "limit_tokens" in config_dict:
+                    assert config_dict["limit_tokens"] is False
+
+            # Check provider and model settings via set method
+            mock_config_repository.set.assert_any_call("provider", "anthropic")
+            mock_config_repository.set.assert_any_call("model", "claude-3-7-sonnet-20250219")
+            mock_config_repository.set.assert_any_call("expert_provider", "openai")
+            mock_config_repository.set.assert_any_call("expert_model", "gpt-4")

-def test_temperature_validation(mock_dependencies):
+def test_temperature_validation(mock_dependencies, mock_config_repository):
     """Test that temperature argument is correctly passed to initialize_llm."""
     import sys
     from unittest.mock import patch, ANY
     from ra_aid.__main__ import main

-    # Reset global memory for clean test
-    _global_memory.clear()
-    _global_memory["config"] = {}
-
-    # Test valid temperature (0.7)
-    with patch("ra_aid.__main__.initialize_llm", return_value=None) as mock_init_llm:
-        # Also patch any calls that would actually use the mocked initialize_llm function
-        with patch("ra_aid.__main__.run_research_agent", return_value=None):
-            with patch("ra_aid.__main__.run_planning_agent", return_value=None):
-                with patch.object(
-                    sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"]
-                ):
-                    main()
-                    # Check if temperature was stored in config correctly
-                    assert _global_memory["config"]["temperature"] == 0.7
+    # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
+    with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
+        # Test valid temperature (0.7)
+        with patch("ra_aid.__main__.initialize_llm", return_value=None) as mock_init_llm:
+            # Also patch any calls that would actually use the mocked initialize_llm function
+            with patch("ra_aid.__main__.run_research_agent", return_value=None):
+                with patch("ra_aid.__main__.run_planning_agent", return_value=None):
+                    with patch.object(
+                        sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"]
+                    ):
+                        main()
+                        # Verify that the temperature was set in the config repository
+                        mock_config_repository.set.assert_any_call("temperature", 0.7)

     # Test invalid temperature (2.1)
     with pytest.raises(SystemExit):
@ -230,61 +304,67 @@ def test_missing_message():
assert args.message == "test" assert args.message == "test"
-def test_research_model_provider_args(mock_dependencies):
+def test_research_model_provider_args(mock_dependencies, mock_config_repository):
    """Test that research-specific model/provider args are correctly stored in config."""
    import sys
    from unittest.mock import patch
    from ra_aid.__main__ import main

-    _global_memory.clear()
-    with patch.object(
-        sys,
-        "argv",
-        [
-            "ra-aid",
-            "-m",
-            "test message",
-            "--research-provider",
-            "anthropic",
-            "--research-model",
-            "claude-3-haiku-20240307",
-            "--planner-provider",
-            "openai",
-            "--planner-model",
-            "gpt-4",
-        ],
-    ):
-        main()
-    config = _global_memory["config"]
-    assert config["research_provider"] == "anthropic"
-    assert config["research_model"] == "claude-3-haiku-20240307"
-    assert config["planner_provider"] == "openai"
-    assert config["planner_model"] == "gpt-4"
+    # Reset mocks
+    mock_config_repository.set.reset_mock()
+    # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
+    with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
+        with patch.object(
+            sys,
+            "argv",
+            [
+                "ra-aid",
+                "-m",
+                "test message",
+                "--research-provider",
+                "anthropic",
+                "--research-model",
+                "claude-3-haiku-20240307",
+                "--planner-provider",
+                "openai",
+                "--planner-model",
+                "gpt-4",
+            ],
+        ):
+            main()
+        # Verify the mock repo's set method was called with the expected values
+        mock_config_repository.set.assert_any_call("research_provider", "anthropic")
+        mock_config_repository.set.assert_any_call("research_model", "claude-3-haiku-20240307")
+        mock_config_repository.set.assert_any_call("planner_provider", "openai")
+        mock_config_repository.set.assert_any_call("planner_model", "gpt-4")
-def test_planner_model_provider_args(mock_dependencies):
+def test_planner_model_provider_args(mock_dependencies, mock_config_repository):
    """Test that planner provider/model args fall back to main config when not specified."""
    import sys
    from unittest.mock import patch
    from ra_aid.__main__ import main

-    _global_memory.clear()
-    with patch.object(
-        sys,
-        "argv",
-        ["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"],
-    ):
-        main()
-    config = _global_memory["config"]
-    assert config["planner_provider"] == "openai"
-    assert config["planner_model"] == "gpt-4"
+    # Reset mocks
+    mock_config_repository.set.reset_mock()
+    # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
+    with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
+        with patch.object(
+            sys,
+            "argv",
+            ["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"],
+        ):
+            main()
+        # Verify the mock repo's set method was called with the expected values
+        mock_config_repository.set.assert_any_call("planner_provider", "openai")
+        mock_config_repository.set.assert_any_call("planner_model", "gpt-4")
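The fallback this test encodes (planner provider/model inheriting the main --provider/--model) presumably lives in the argument handling in __main__; a rough sketch of that shape, with the args and config_repo names assumed rather than taken from the commit:

# Hypothetical sketch of the fallback behavior exercised above.
planner_provider = args.planner_provider or args.provider
planner_model = args.planner_model or args.model
config_repo.set("planner_provider", planner_provider)
config_repo.set("planner_model", planner_model)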
-def test_use_aider_flag(mock_dependencies):
+def test_use_aider_flag(mock_dependencies, mock_config_repository):
    """Test that use-aider flag is correctly stored in config."""
    import sys
    from unittest.mock import patch
@@ -292,44 +372,68 @@ def test_use_aider_flag(mock_dependencies):
    from ra_aid.__main__ import main
    from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools

-    _global_memory.clear()
+    # Reset mocks
+    mock_config_repository.update.reset_mock()

    # Reset to default state
    set_modification_tools(False)

-    # Check default behavior (use_aider=False)
-    with patch.object(
-        sys,
-        "argv",
-        ["ra-aid", "-m", "test message"],
-    ):
-        main()
-    config = _global_memory["config"]
-    assert config.get("use_aider") is False
+    # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
+    with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
+        # Check default behavior (use_aider=False)
+        with patch.object(
+            sys,
+            "argv",
+            ["ra-aid", "-m", "test message"],
+        ):
+            main()
+        # Verify use_aider is set to False in the update call
+        mock_config_repository.update.assert_called()
+        # Get the call arguments
+        call_args = mock_config_repository.update.call_args_list
+        # Find the call that includes use_aider
+        use_aider_found = False
+        for args, _ in call_args:
+            config_dict = args[0]
+            if "use_aider" in config_dict and config_dict["use_aider"] is False:
+                use_aider_found = True
+                break
+        assert use_aider_found, f"use_aider=False not found in update calls: {call_args}"

    # Check that file tools are enabled by default
    tool_names = [tool.name for tool in MODIFICATION_TOOLS]
    assert "file_str_replace" in tool_names
    assert "put_complete_file_contents" in tool_names
    assert "run_programming_task" not in tool_names

-    _global_memory.clear()
+    # Reset mocks
+    mock_config_repository.update.reset_mock()

    # Check with --use-aider flag
    with patch.object(
        sys,
        "argv",
        ["ra-aid", "-m", "test message", "--use-aider"],
    ):
        main()
-    config = _global_memory["config"]
-    assert config.get("use_aider") is True
+    # Verify use_aider is set to True in the update call
+    mock_config_repository.update.assert_called()
+    # Get the call arguments
+    call_args = mock_config_repository.update.call_args_list
+    # Find the call that includes use_aider
+    use_aider_found = False
+    for args, _ in call_args:
+        config_dict = args[0]
+        if "use_aider" in config_dict and config_dict["use_aider"] is True:
+            use_aider_found = True
+            break
+    assert use_aider_found, f"use_aider=True not found in update calls: {call_args}"

    # Check that run_programming_task is enabled
    tool_names = [tool.name for tool in MODIFICATION_TOOLS]
    assert "file_str_replace" not in tool_names
    assert "put_complete_file_contents" not in tool_names
    assert "run_programming_task" in tool_names

    # Reset to default state for other tests
    set_modification_tools(False)
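The assertions above pin down what set_modification_tools has to do: swap MODIFICATION_TOOLS between the direct file-editing tools and run_programming_task. A plausible shape, inferred from these assertions rather than copied from the implementation:

def set_modification_tools(use_aider: bool) -> None:
    # Inferred sketch: replace the active modification tools in place.
    MODIFICATION_TOOLS.clear()
    if use_aider:
        MODIFICATION_TOOLS.append(run_programming_task)
    else:
        MODIFICATION_TOOLS.extend([file_str_replace, put_complete_file_contents])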
View File

@@ -1,3 +1,4 @@
+import os
import pytest
from unittest.mock import patch, MagicMock
@@ -7,6 +8,36 @@ from ra_aid.tools.programmer import (
    run_programming_task,
)
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
+from ra_aid.database.repositories.config_repository import get_config_repository
+@pytest.fixture(autouse=True)
+def mock_config_repository():
+    """Mock the ConfigRepository to avoid database operations during tests"""
+    with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
+        # Setup a mock repository
+        mock_repo = MagicMock()
+        # Create a dictionary to simulate config
+        config = {
+            "recursion_limit": 2,
+            "provider": "anthropic",
+            "model": "claude-3-5-sonnet-20241022",
+            "temperature": 0.01,
+            "aider_config": "/path/to/config.yml"
+        }
+        # Setup get_all method to return the config dict
+        mock_repo.get_all.return_value = config
+        # Setup get method to return config values
+        def get_config(key, default=None):
+            return config.get(key, default)
+        mock_repo.get.side_effect = get_config
+        # Make the mock context var return our mock repo
+        mock_repo_var.get.return_value = mock_repo
+        yield mock_repo
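Patching config_repo_var works because get_config_repository() appears to resolve the active repository through that context variable, so any code under test that calls it receives the mock above. A minimal usage sketch under that assumption (the test name is invented):

from ra_aid.database.repositories.config_repository import get_config_repository

def test_reads_mocked_config(mock_config_repository):
    repo = get_config_repository()  # resolves to the MagicMock yielded by the fixture
    assert repo.get("provider") == "anthropic"
    assert repo.get("missing_key", "fallback") == "fallback"  # default falls through the side_effect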
@pytest.fixture(autouse=True)
def mock_related_files_repository():
@@ -125,13 +156,9 @@ def test_parse_aider_flags(input_flags, expected, description):
    assert result == expected, f"Failed test case: {description}"

-def test_aider_config_flag(mocker, mock_related_files_repository):
+def test_aider_config_flag(mocker, mock_config_repository, mock_related_files_repository):
    """Test that aider config flag is properly included in the command when specified."""
-    # Mock config in global memory but not related files (using repository now)
-    mock_memory = {
-        "config": {"aider_config": "/path/to/config.yml"},
-    }
-    mocker.patch("ra_aid.tools.programmer._global_memory", mock_memory)
+    # Config is mocked by mock_config_repository fixture

    # Mock the run_interactive_command to capture the command that would be run
    mock_run = mocker.patch(
@@ -146,15 +173,14 @@ def test_aider_config_flag(mocker, mock_related_files_repository):
    assert args[config_index + 1] == "/path/to/config.yml"

-def test_path_normalization_and_deduplication(mocker, tmp_path, mock_related_files_repository):
+def test_path_normalization_and_deduplication(mocker, tmp_path, mock_config_repository, mock_related_files_repository):
    """Test path normalization and deduplication in run_programming_task."""
    # Create a temporary test file
    test_file = tmp_path / "test.py"
    test_file.write_text("")
    new_file = tmp_path / "new.py"

-    # Mock dependencies - only need to mock config part of global memory now
-    mocker.patch("ra_aid.tools.programmer._global_memory", {"config": {}})
+    # Config is mocked by mock_config_repository fixture
    mocker.patch(
        "ra_aid.tools.programmer.get_aider_executable", return_value="/path/to/aider"
    )
View File
@@ -13,6 +13,7 @@ from ra_aid.tools.agent import (
from ra_aid.tools.memory import _global_memory
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
from ra_aid.database.repositories.work_log_repository import get_work_log_repository, WorkLogEntry
+from ra_aid.database.repositories.config_repository import get_config_repository

@pytest.fixture
@@ -43,6 +44,34 @@ def mock_related_files_repository():
        yield mock_repo
+@pytest.fixture(autouse=True)
+def mock_config_repository():
+    """Mock the ConfigRepository to avoid database operations during tests"""
+    with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
+        # Setup a mock repository
+        mock_repo = MagicMock()
+        # Create a dictionary to simulate config
+        config = {
+            "recursion_limit": 2,
+            "provider": "anthropic",
+            "model": "claude-3-5-sonnet-20241022",
+            "temperature": 0.01
+        }
+        # Setup get_all method to return the config dict
+        mock_repo.get_all.return_value = config
+        # Setup get method to return config values
+        def get_config(key, default=None):
+            return config.get(key, default)
+        mock_repo.get.side_effect = get_config
+        # Make the mock context var return our mock repo
+        mock_repo_var.get.return_value = mock_repo
+        yield mock_repo
@pytest.fixture(autouse=True)
def mock_work_log_repository():
    """Mock the WorkLogRepository to avoid database operations during tests"""
View File
@@ -1,8 +1,8 @@
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
import pytest

-from ra_aid.tools.memory import _global_memory
+from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
from ra_aid.tools.shell import run_shell_command
@@ -25,9 +25,38 @@ def mock_run_interactive():
        yield mock
-def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive):
+@pytest.fixture(autouse=True)
+def mock_config_repository():
+    """Mock the ConfigRepository to avoid database operations during tests"""
+    with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
+        # Setup a mock repository
+        mock_repo = MagicMock()
+        # Create a dictionary to simulate config
+        config = {
+            "cowboy_mode": False
+        }
+        # Setup get method to return config values
+        def get_config(key, default=None):
+            return config.get(key, default)
+        mock_repo.get.side_effect = get_config
+        # Setup set method to update config values
+        def set_config(key, value):
+            config[key] = value
+        mock_repo.set.side_effect = set_config
+        # Make the mock context var return our mock repo
+        mock_repo_var.get.return_value = mock_repo
+        yield mock_repo
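Unlike the fixtures above, this one also wires set, so a test can flip cowboy_mode and the shell tool, which presumably reads the flag via get_config_repository().get("cowboy_mode", False) in line with the rest of this commit, sees the change through the same backing dict. Roughly:

mock_config_repository.set("cowboy_mode", True)            # writes into the fixture's config dict
assert mock_config_repository.get("cowboy_mode") is True   # what run_shell_command would read back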
+def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
    """Test shell command execution in cowboy mode (no approval)"""
-    _global_memory["config"] = {"cowboy_mode": True}
+    # Set cowboy mode to True using the repository
+    mock_config_repository.set("cowboy_mode", True)

    result = run_shell_command.invoke({"command": "echo test"})
@@ -37,9 +66,10 @@ def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interacti
    mock_prompt.ask.assert_not_called()

-def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive):
+def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
    """Test that cowboy mode displays a properly formatted cowboy message with correct spacing"""
-    _global_memory["config"] = {"cowboy_mode": True}
+    # Set cowboy mode to True using the repository
+    mock_config_repository.set("cowboy_mode", True)

    with patch("ra_aid.tools.shell.get_cowboy_message") as mock_get_message:
        mock_get_message.return_value = "🤠 Test cowboy message!"
@@ -53,10 +83,11 @@ def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_intera
def test_shell_command_interactive_approved(
-    mock_console, mock_prompt, mock_run_interactive
+    mock_console, mock_prompt, mock_run_interactive, mock_config_repository
):
    """Test shell command execution with interactive approval"""
-    _global_memory["config"] = {"cowboy_mode": False}
+    # Set cowboy mode to False using the repository
+    mock_config_repository.set("cowboy_mode", False)

    mock_prompt.ask.return_value = "y"
    result = run_shell_command.invoke({"command": "echo test"})
@@ -74,10 +105,11 @@ def test_shell_command_interactive_approved(
def test_shell_command_interactive_rejected(
-    mock_console, mock_prompt, mock_run_interactive
+    mock_console, mock_prompt, mock_run_interactive, mock_config_repository
):
    """Test shell command rejection in interactive mode"""
-    _global_memory["config"] = {"cowboy_mode": False}
+    # Set cowboy mode to False using the repository
+    mock_config_repository.set("cowboy_mode", False)

    mock_prompt.ask.return_value = "n"
    result = run_shell_command.invoke({"command": "echo test"})
@@ -95,13 +127,14 @@ def test_shell_command_interactive_rejected(
    mock_run_interactive.assert_not_called()

-def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive):
+def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
    """Test handling of shell command execution errors"""
-    _global_memory["config"] = {"cowboy_mode": True}
+    # Set cowboy mode to True using the repository
+    mock_config_repository.set("cowboy_mode", True)
    mock_run_interactive.side_effect = Exception("Command failed")

    result = run_shell_command.invoke({"command": "invalid command"})
    assert result["success"] is False
    assert result["return_code"] == 1
    assert "Command failed" in result["output"]