From 5bd8c76a222e8021a34ee4740744cab9797d0d08 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Tue, 4 Mar 2025 21:01:08 -0500 Subject: [PATCH] config repository --- ra_aid/__main__.py | 59 +-- ra_aid/agent_utils.py | 16 +- ra_aid/agents/key_facts_gc_agent.py | 5 +- ra_aid/agents/key_snippets_gc_agent.py | 5 +- ra_aid/agents/research_notes_gc_agent.py | 5 +- .../repositories/config_repository.py | 165 ++++++++ ra_aid/tool_configs.py | 29 +- ra_aid/tools/agent.py | 12 +- ra_aid/tools/expert.py | 9 +- ra_aid/tools/human.py | 4 +- ra_aid/tools/programmer.py | 13 +- ra_aid/tools/shell.py | 7 +- tests/conftest.py | 15 + tests/ra_aid/test_agent_utils.py | 91 +++-- tests/ra_aid/test_info_query.py | 36 +- tests/ra_aid/test_main.py | 382 +++++++++++------- tests/ra_aid/test_programmer.py | 44 +- tests/ra_aid/tools/test_agent.py | 29 ++ tests/ra_aid/tools/test_shell.py | 59 ++- 19 files changed, 695 insertions(+), 290 deletions(-) create mode 100644 ra_aid/database/repositories/config_repository.py diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index d18c90a..3d02851 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -58,6 +58,10 @@ from ra_aid.database.repositories.related_files_repository import ( from ra_aid.database.repositories.work_log_repository import ( WorkLogRepositoryManager ) +from ra_aid.database.repositories.config_repository import ( + ConfigRepositoryManager, + get_config_repository +) from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.console.output import cpm @@ -77,7 +81,7 @@ from ra_aid.prompts.chat_prompts import CHAT_PROMPT from ra_aid.prompts.web_research_prompts import WEB_RESEARCH_PROMPT_SECTION_CHAT from ra_aid.tool_configs import get_chat_tools, set_modification_tools from ra_aid.tools.human import ask_human -from ra_aid.tools.memory import _global_memory, get_memory_value +from ra_aid.tools.memory import get_memory_value logger = get_logger(__name__) @@ -338,7 +342,7 @@ implementation_memory = MemorySaver() def is_informational_query() -> bool: """Determine if the current query is informational based on config settings.""" - return _global_memory.get("config", {}).get("research_only", False) + return get_config_repository().get("research_only", False) def is_stage_requested(stage: str) -> bool: @@ -404,13 +408,17 @@ def main(): except Exception as e: logger.error(f"Database migration error: {str(e)}") + # Initialize empty config dictionary to be populated later + config = {} + # Initialize repositories with database connection with KeyFactRepositoryManager(db) as key_fact_repo, \ KeySnippetRepositoryManager(db) as key_snippet_repo, \ HumanInputRepositoryManager(db) as human_input_repo, \ ResearchNoteRepositoryManager(db) as research_note_repo, \ RelatedFilesRepositoryManager() as related_files_repo, \ - WorkLogRepositoryManager() as work_log_repo: + WorkLogRepositoryManager() as work_log_repo, \ + ConfigRepositoryManager(config) as config_repo: # This initializes all repositories and makes them available via their respective get methods logger.debug("Initialized KeyFactRepository") logger.debug("Initialized KeySnippetRepository") @@ -418,6 +426,7 @@ def main(): logger.debug("Initialized ResearchNoteRepository") logger.debug("Initialized RelatedFilesRepository") logger.debug("Initialized WorkLogRepository") + logger.debug("Initialized ConfigRepository") # Check dependencies before proceeding check_dependencies() @@ -520,13 +529,13 @@ def main(): "limit_tokens": args.disable_limit_tokens, } - # Store config in global memory - _global_memory["config"] = config - _global_memory["config"]["provider"] = args.provider - _global_memory["config"]["model"] = args.model - _global_memory["config"]["expert_provider"] = args.expert_provider - _global_memory["config"]["expert_model"] = args.expert_model - _global_memory["config"]["temperature"] = args.temperature + # Store config in repository + config_repo.update(config) + config_repo.set("provider", args.provider) + config_repo.set("model", args.model) + config_repo.set("expert_provider", args.expert_provider) + config_repo.set("expert_model", args.expert_model) + config_repo.set("temperature", args.temperature) # Set modification tools based on use_aider flag set_modification_tools(args.use_aider) @@ -594,33 +603,27 @@ def main(): "test_cmd_timeout": args.test_cmd_timeout, } - # Store config in global memory for access by is_informational_query - _global_memory["config"] = config + # Store config in repository + config_repo.update(config) # Store base provider/model configuration - _global_memory["config"]["provider"] = args.provider - _global_memory["config"]["model"] = args.model + config_repo.set("provider", args.provider) + config_repo.set("model", args.model) # Store expert provider/model (no fallback) - _global_memory["config"]["expert_provider"] = args.expert_provider - _global_memory["config"]["expert_model"] = args.expert_model + config_repo.set("expert_provider", args.expert_provider) + config_repo.set("expert_model", args.expert_model) # Store planner config with fallback to base values - _global_memory["config"]["planner_provider"] = ( - args.planner_provider or args.provider - ) - _global_memory["config"]["planner_model"] = args.planner_model or args.model + config_repo.set("planner_provider", args.planner_provider or args.provider) + config_repo.set("planner_model", args.planner_model or args.model) # Store research config with fallback to base values - _global_memory["config"]["research_provider"] = ( - args.research_provider or args.provider - ) - _global_memory["config"]["research_model"] = ( - args.research_model or args.model - ) + config_repo.set("research_provider", args.research_provider or args.provider) + config_repo.set("research_model", args.research_model or args.model) - # Store temperature in global config - _global_memory["config"]["temperature"] = args.temperature + # Store temperature in config + config_repo.set("temperature", args.temperature) # Set modification tools based on use_aider flag set_modification_tools(args.use_aider) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 7ef9efa..3460acb 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -94,11 +94,11 @@ from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict from ra_aid.tools.memory import ( - _global_memory, get_memory_value, get_related_files, log_work_event, ) +from ra_aid.database.repositories.config_repository import get_config_repository console = Console() @@ -302,7 +302,7 @@ def create_agent( config['limit_tokens'] = False. """ try: - config = _global_memory.get("config", {}) + config = get_config_repository().get_all() max_input_tokens = ( get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT ) @@ -319,7 +319,7 @@ def create_agent( except Exception as e: # Default to REACT agent if provider/model detection fails logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.") - config = _global_memory.get("config", {}) + config = get_config_repository().get_all() max_input_tokens = get_model_token_limit(config, agent_type) agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens) return create_react_agent(model, tools, **agent_kwargs) @@ -443,7 +443,7 @@ def run_research_agent( new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "", ) - config = _global_memory.get("config", {}) if not config else config + config = get_config_repository().get_all() if not config else config recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) run_config = { "configurable": {"thread_id": thread_id}, @@ -575,7 +575,7 @@ def run_web_research_agent( related_files=related_files, ) - config = _global_memory.get("config", {}) if not config else config + config = get_config_repository().get_all() if not config else config recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) run_config = { @@ -709,7 +709,7 @@ def run_planning_agent( ), ) - config = _global_memory.get("config", {}) if not config else config + config = get_config_repository().get_all() if not config else config recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) run_config = { "configurable": {"thread_id": thread_id}, @@ -824,7 +824,7 @@ def run_task_implementation_agent( expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "", human_section=( HUMAN_PROMPT_SECTION_IMPLEMENTATION - if _global_memory.get("config", {}).get("hil", False) + if get_config_repository().get("hil", False) else "" ), web_research_section=( @@ -834,7 +834,7 @@ def run_task_implementation_agent( ), ) - config = _global_memory.get("config", {}) if not config else config + config = get_config_repository().get_all() if not config else config recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT) run_config = { "configurable": {"thread_id": thread_id}, diff --git a/ra_aid/agents/key_facts_gc_agent.py b/ra_aid/agents/key_facts_gc_agent.py index e8eae2c..ece3c3d 100644 --- a/ra_aid/agents/key_facts_gc_agent.py +++ b/ra_aid/agents/key_facts_gc_agent.py @@ -19,9 +19,10 @@ logger = logging.getLogger(__name__) from ra_aid.agent_utils import create_agent, run_agent_with_retry from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository from ra_aid.database.repositories.human_input_repository import get_human_input_repository +from ra_aid.database.repositories.config_repository import get_config_repository from ra_aid.llm import initialize_llm from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT -from ra_aid.tools.memory import log_work_event, _global_memory +from ra_aid.tools.memory import log_work_event console = Console() @@ -149,7 +150,7 @@ def run_key_facts_gc_agent() -> None: formatted_facts = "\n".join([f"Fact #{k}: {v}" for k, v in facts_dict.items()]) # Retrieve configuration - llm_config = _global_memory.get("config", {}) + llm_config = get_config_repository().get_all() # Initialize the LLM model model = initialize_llm( diff --git a/ra_aid/agents/key_snippets_gc_agent.py b/ra_aid/agents/key_snippets_gc_agent.py index 120d6e6..72e4c38 100644 --- a/ra_aid/agents/key_snippets_gc_agent.py +++ b/ra_aid/agents/key_snippets_gc_agent.py @@ -16,9 +16,10 @@ from rich.panel import Panel from ra_aid.agent_utils import create_agent, run_agent_with_retry from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository from ra_aid.database.repositories.human_input_repository import get_human_input_repository +from ra_aid.database.repositories.config_repository import get_config_repository from ra_aid.llm import initialize_llm from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT -from ra_aid.tools.memory import log_work_event, _global_memory +from ra_aid.tools.memory import log_work_event console = Console() @@ -153,7 +154,7 @@ def run_key_snippets_gc_agent() -> None: ]) # Retrieve configuration - llm_config = _global_memory.get("config", {}) + llm_config = get_config_repository().get_all() # Initialize the LLM model model = initialize_llm( diff --git a/ra_aid/agents/research_notes_gc_agent.py b/ra_aid/agents/research_notes_gc_agent.py index d992137..9fe168e 100644 --- a/ra_aid/agents/research_notes_gc_agent.py +++ b/ra_aid/agents/research_notes_gc_agent.py @@ -19,9 +19,10 @@ logger = logging.getLogger(__name__) from ra_aid.agent_utils import create_agent, run_agent_with_retry from ra_aid.database.repositories.research_note_repository import get_research_note_repository from ra_aid.database.repositories.human_input_repository import get_human_input_repository +from ra_aid.database.repositories.config_repository import get_config_repository from ra_aid.llm import initialize_llm from ra_aid.model_formatters.research_notes_formatter import format_research_note -from ra_aid.tools.memory import log_work_event, _global_memory +from ra_aid.tools.memory import log_work_event console = Console() @@ -154,7 +155,7 @@ def run_research_notes_gc_agent(threshold: int = 30) -> None: formatted_notes = "\n".join([f"Note #{k}: {v}" for k, v in notes_dict.items()]) # Retrieve configuration - llm_config = _global_memory.get("config", {}) + llm_config = get_config_repository().get_all() # Initialize the LLM model model = initialize_llm( diff --git a/ra_aid/database/repositories/config_repository.py b/ra_aid/database/repositories/config_repository.py new file mode 100644 index 0000000..70be4cc --- /dev/null +++ b/ra_aid/database/repositories/config_repository.py @@ -0,0 +1,165 @@ +"""Repository for managing configuration values.""" + +import contextvars +from typing import Any, Dict, Optional + +# Create contextvar to hold the ConfigRepository instance +config_repo_var = contextvars.ContextVar("config_repo", default=None) + + +class ConfigRepository: + """ + Repository for managing configuration values in memory. + + This class provides methods to get, set, update, and retrieve all configuration values. + It does not require database models and operates entirely in memory. + """ + + def __init__(self, initial_config: Optional[Dict[str, Any]] = None): + """ + Initialize the ConfigRepository. + + Args: + initial_config: Optional dictionary of initial configuration values + """ + self._config: Dict[str, Any] = {} + + # Initialize with default values from config.py + from ra_aid.config import ( + DEFAULT_RECURSION_LIMIT, + DEFAULT_MAX_TEST_CMD_RETRIES, + DEFAULT_MAX_TOOL_FAILURES, + FALLBACK_TOOL_MODEL_LIMIT, + RETRY_FALLBACK_COUNT, + DEFAULT_TEST_CMD_TIMEOUT, + VALID_PROVIDERS, + ) + + self._config = { + "recursion_limit": DEFAULT_RECURSION_LIMIT, + "max_test_cmd_retries": DEFAULT_MAX_TEST_CMD_RETRIES, + "max_tool_failures": DEFAULT_MAX_TOOL_FAILURES, + "fallback_tool_model_limit": FALLBACK_TOOL_MODEL_LIMIT, + "retry_fallback_count": RETRY_FALLBACK_COUNT, + "test_cmd_timeout": DEFAULT_TEST_CMD_TIMEOUT, + "valid_providers": VALID_PROVIDERS, + } + + # Update with any provided initial configuration + if initial_config: + self._config.update(initial_config) + + def get(self, key: str, default: Any = None) -> Any: + """ + Get a configuration value by key. + + Args: + key: Configuration key to retrieve + default: Default value to return if key is not found + + Returns: + The configuration value or default if not found + """ + return self._config.get(key, default) + + def set(self, key: str, value: Any) -> None: + """ + Set a configuration value by key. + + Args: + key: Configuration key to set + value: Value to set for the key + """ + self._config[key] = value + + def update(self, config_dict: Dict[str, Any]) -> None: + """ + Update multiple configuration values at once. + + Args: + config_dict: Dictionary of configuration key-value pairs to update + """ + self._config.update(config_dict) + + def get_all(self) -> Dict[str, Any]: + """ + Get all configuration values. + + Returns: + Dictionary containing all configuration values + """ + return self._config.copy() + + +class ConfigRepositoryManager: + """ + Context manager for ConfigRepository. + + This class provides a context manager interface for ConfigRepository, + using the contextvars approach for thread safety. + + Example: + with ConfigRepositoryManager() as repo: + # Use the repository + value = repo.get("key") + repo.set("key", new_value) + """ + + def __init__(self, initial_config: Optional[Dict[str, Any]] = None): + """ + Initialize the ConfigRepositoryManager. + + Args: + initial_config: Optional dictionary of initial configuration values + """ + self.initial_config = initial_config + + def __enter__(self) -> 'ConfigRepository': + """ + Initialize the ConfigRepository and return it. + + Returns: + ConfigRepository: The initialized repository + """ + repo = ConfigRepository(self.initial_config) + config_repo_var.set(repo) + return repo + + def __exit__( + self, + exc_type: Optional[type], + exc_val: Optional[Exception], + exc_tb: Optional[object], + ) -> None: + """ + Reset the repository when exiting the context. + + Args: + exc_type: The exception type if an exception was raised + exc_val: The exception value if an exception was raised + exc_tb: The traceback if an exception was raised + """ + # Reset the contextvar to None + config_repo_var.set(None) + + # Don't suppress exceptions + return False + + +def get_config_repository() -> ConfigRepository: + """ + Get the current ConfigRepository instance. + + Returns: + ConfigRepository: The current repository instance + + Raises: + RuntimeError: If no repository is set in the current context + """ + repo = config_repo_var.get() + if repo is None: + raise RuntimeError( + "ConfigRepository not initialized in current context. " + "Make sure to use ConfigRepositoryManager." + ) + return repo \ No newline at end of file diff --git a/ra_aid/tool_configs.py b/ra_aid/tool_configs.py index 0de5553..6eba30e 100644 --- a/ra_aid/tool_configs.py +++ b/ra_aid/tool_configs.py @@ -27,6 +27,7 @@ from ra_aid.tools.agent import ( request_web_research, ) from ra_aid.tools.memory import plan_implementation_completed +from ra_aid.database.repositories.config_repository import get_config_repository def set_modification_tools(use_aider=False): @@ -98,13 +99,11 @@ def get_all_tools() -> list[BaseTool]: # Define constant tool groups -# Get config from global memory for use_aider value +# Get config from repository for use_aider value _config = {} try: - from ra_aid.tools.memory import _global_memory - - _config = _global_memory.get("config", {}) -except ImportError: + _config = get_config_repository().get_all() +except (ImportError, RuntimeError): pass READ_ONLY_TOOLS = get_read_only_tools(use_aider=_config.get("use_aider", False)) @@ -139,10 +138,8 @@ def get_research_tools( # Get config for use_aider value use_aider = False try: - from ra_aid.tools.memory import _global_memory - - use_aider = _global_memory.get("config", {}).get("use_aider", False) - except ImportError: + use_aider = get_config_repository().get("use_aider", False) + except (ImportError, RuntimeError): pass # Start with read-only tools @@ -180,10 +177,8 @@ def get_planning_tools( # Get config for use_aider value use_aider = False try: - from ra_aid.tools.memory import _global_memory - - use_aider = _global_memory.get("config", {}).get("use_aider", False) - except ImportError: + use_aider = get_config_repository().get("use_aider", False) + except (ImportError, RuntimeError): pass # Start with read-only tools @@ -219,10 +214,8 @@ def get_implementation_tools( # Get config for use_aider value use_aider = False try: - from ra_aid.tools.memory import _global_memory - - use_aider = _global_memory.get("config", {}).get("use_aider", False) - except ImportError: + use_aider = get_config_repository().get("use_aider", False) + except (ImportError, RuntimeError): pass # Start with read-only tools @@ -285,4 +278,4 @@ def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = Fal if web_research_enabled: tools.append(request_web_research) - return tools + return tools \ No newline at end of file diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index a87841f..7e32afb 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -18,13 +18,13 @@ from ra_aid.console.formatting import print_error from ra_aid.database.repositories.human_input_repository import HumanInputRepository from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository +from ra_aid.database.repositories.config_repository import get_config_repository from ra_aid.database.repositories.related_files_repository import get_related_files_repository from ra_aid.database.repositories.research_note_repository import get_research_note_repository from ra_aid.exceptions import AgentInterrupt from ra_aid.model_formatters import format_key_facts_dict from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict -from ra_aid.tools.memory import _global_memory from ..console import print_task_header from ..llm import initialize_llm @@ -52,7 +52,7 @@ def request_research(query: str) -> ResearchResult: query: The research question or project description """ # Initialize model from config - config = _global_memory.get("config", {}) + config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), config.get("model", "claude-3-7-sonnet-20250219"), @@ -165,7 +165,7 @@ def request_web_research(query: str) -> ResearchResult: query: The research question or project description """ # Initialize model from config - config = _global_memory.get("config", {}) + config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), config.get("model", "claude-3-7-sonnet-20250219"), @@ -246,7 +246,7 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: query: The research question or project description """ # Initialize model from config - config = _global_memory.get("config", {}) + config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), config.get("model", "claude-3-7-sonnet-20250219"), @@ -335,7 +335,7 @@ def request_task_implementation(task_spec: str) -> str: task_spec: REQUIRED The full task specification (markdown format, typically one part of the overall plan) """ # Initialize model from config - config = _global_memory.get("config", {}) + config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), config.get("model", "claude-3-5-sonnet-20241022"), @@ -474,7 +474,7 @@ def request_implementation(task_spec: str) -> str: task_spec: The task specification to plan implementation for """ # Initialize model from config - config = _global_memory.get("config", {}) + config = get_config_repository().get_all() model = initialize_llm( config.get("provider", "anthropic"), config.get("model", "claude-3-5-sonnet-20241022"), diff --git a/ra_aid/tools/expert.py b/ra_aid/tools/expert.py index 0cfc0f4..3b563d6 100644 --- a/ra_aid/tools/expert.py +++ b/ra_aid/tools/expert.py @@ -13,11 +13,12 @@ from ..database.repositories.key_fact_repository import get_key_fact_repository from ..database.repositories.key_snippet_repository import get_key_snippet_repository from ..database.repositories.related_files_repository import get_related_files_repository from ..database.repositories.research_note_repository import get_research_note_repository +from ..database.repositories.config_repository import get_config_repository from ..llm import initialize_expert_llm from ..model_formatters import format_key_facts_dict from ..model_formatters.key_snippets_formatter import format_key_snippets_dict from ..model_formatters.research_notes_formatter import format_research_notes_dict -from .memory import _global_memory, get_memory_value +from .memory import get_memory_value console = Console() _model = None @@ -27,9 +28,9 @@ def get_model(): global _model try: if _model is None: - config = _global_memory["config"] - provider = config.get("expert_provider") or config.get("provider") - model = config.get("expert_model") or config.get("model") + config_repo = get_config_repository() + provider = config_repo.get("expert_provider") or config_repo.get("provider") + model = config_repo.get("expert_model") or config_repo.get("model") _model = initialize_expert_llm(provider, model) except Exception as e: _model = None diff --git a/ra_aid/tools/human.py b/ra_aid/tools/human.py index dc77937..ec275c3 100644 --- a/ra_aid/tools/human.py +++ b/ra_aid/tools/human.py @@ -58,10 +58,10 @@ def ask_human(question: str) -> str: # Record human response in database try: from ra_aid.database.repositories.human_input_repository import get_human_input_repository - from ra_aid.tools.memory import _global_memory + from ra_aid.database.repositories.config_repository import get_config_repository # Determine the source based on context - config = _global_memory.get("config", {}) + config = get_config_repository().get_all() # If chat_mode is enabled, use 'chat', otherwise determine if hil mode is active if config.get("chat_mode", False): source = "chat" diff --git a/ra_aid/tools/programmer.py b/ra_aid/tools/programmer.py index 5b6788b..1a5d99c 100644 --- a/ra_aid/tools/programmer.py +++ b/ra_aid/tools/programmer.py @@ -13,7 +13,8 @@ from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_BASE_LATENCY, models_params from ra_aid.proc.interactive import run_interactive_command from ra_aid.text.processing import truncate_output -from ra_aid.tools.memory import _global_memory, log_work_event +from ra_aid.tools.memory import log_work_event +from ra_aid.database.repositories.config_repository import get_config_repository from ra_aid.database.repositories.related_files_repository import get_related_files_repository console = Console() @@ -107,8 +108,9 @@ def run_programming_task( ) # Add config file if specified - if "config" in _global_memory and _global_memory["config"].get("aider_config"): - command.extend(["--config", _global_memory["config"]["aider_config"]]) + config = get_config_repository().get_all() + if config.get("aider_config"): + command.extend(["--config", config["aider_config"]]) # if environment variable AIDER_FLAGS exists then parse if "AIDER_FLAGS" in os.environ: @@ -147,8 +149,9 @@ def run_programming_task( # Run the command interactively print() # Get provider/model specific latency coefficient - provider = _global_memory.get("config", {}).get("provider", "") - model = _global_memory.get("config", {}).get("model", "") + config = get_config_repository().get_all() + provider = config.get("provider", "") + model = config.get("model", "") latency = ( models_params.get(provider, {}) .get(model, {}) diff --git a/ra_aid/tools/shell.py b/ra_aid/tools/shell.py index 48b6247..21733db 100644 --- a/ra_aid/tools/shell.py +++ b/ra_aid/tools/shell.py @@ -9,6 +9,7 @@ from ra_aid.console.cowboy_messages import get_cowboy_message from ra_aid.proc.interactive import run_interactive_command from ra_aid.text.processing import truncate_output from ra_aid.tools.memory import _global_memory, log_work_event +from ra_aid.database.repositories.config_repository import get_config_repository console = Console() @@ -46,7 +47,7 @@ def run_shell_command( 4. Add flags e.g. git --no-pager in order to reduce interaction required by the human. """ # Check if we need approval - cowboy_mode = _global_memory.get("config", {}).get("cowboy_mode", False) + cowboy_mode = get_config_repository().get("cowboy_mode", False) if cowboy_mode: console.print("") @@ -74,7 +75,7 @@ def run_shell_command( "success": False, } elif response == "c": - _global_memory["config"]["cowboy_mode"] = True + get_config_repository().set("cowboy_mode", True) console.print("") console.print(" " + get_cowboy_message()) console.print("") @@ -96,4 +97,4 @@ def run_shell_command( except Exception as e: print() console.print(Panel(str(e), title="❌ Error", border_style="red")) - return {"output": str(e), "return_code": 1, "success": False} + return {"output": str(e), "return_code": 1, "success": False} \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 08fbadb..76c4524 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,10 +7,25 @@ ensuring consistent test environments and proper isolation. import os from pathlib import Path +from unittest.mock import MagicMock import pytest +@pytest.fixture() +def mock_config_repository(): + """Mock the config repository.""" + from ra_aid.database.repositories.config_repository import get_config_repository + + repo = MagicMock() + # Default config values + config_values = {"recursion_limit": 2} + repo.get_all.return_value = config_values + repo.get.side_effect = lambda key, default=None: config_values.get(key, default) + get_config_repository.return_value = repo + yield repo + + @pytest.fixture(autouse=True) def isolated_db_environment(tmp_path, monkeypatch): """ diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 591c0db..3927321 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -1,7 +1,7 @@ """Unit tests for agent_utils.py.""" from typing import Any, Dict, Literal -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock import litellm import pytest @@ -19,6 +19,7 @@ from ra_aid.agent_utils import ( state_modifier, ) from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params +from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var @pytest.fixture @@ -29,40 +30,70 @@ def mock_model(): @pytest.fixture -def mock_memory(): - """Fixture providing a mock global memory store.""" - with patch("ra_aid.agent_utils._global_memory") as mock_mem: - mock_mem.get.return_value = {} - yield mock_mem +def mock_config_repository(): + """Mock the ConfigRepository to avoid database operations during tests""" + with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var: + # Setup a mock repository + mock_repo = MagicMock() + + # Create a dictionary to simulate config + config = {} + + # Setup get method to return config values + def get_config(key, default=None): + return config.get(key, default) + mock_repo.get.side_effect = get_config + + # Setup get_all method to return all config values + mock_repo.get_all.return_value = config + + # Setup set method to update config values + def set_config(key, value): + config[key] = value + mock_repo.set.side_effect = set_config + + # Setup update method to update multiple config values + def update_config(update_dict): + config.update(update_dict) + mock_repo.update.side_effect = update_config + + # Make the mock context var return our mock repo + mock_repo_var.get.return_value = mock_repo + + yield mock_repo -def test_get_model_token_limit_anthropic(mock_memory): +def test_get_model_token_limit_anthropic(mock_config_repository): """Test get_model_token_limit with Anthropic model.""" config = {"provider": "anthropic", "model": "claude2"} + mock_config_repository.update(config) token_limit = get_model_token_limit(config, "default") assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] -def test_get_model_token_limit_openai(mock_memory): +def test_get_model_token_limit_openai(mock_config_repository): """Test get_model_token_limit with OpenAI model.""" config = {"provider": "openai", "model": "gpt-4"} + mock_config_repository.update(config) token_limit = get_model_token_limit(config, "default") assert token_limit == models_params["openai"]["gpt-4"]["token_limit"] -def test_get_model_token_limit_unknown(mock_memory): +def test_get_model_token_limit_unknown(mock_config_repository): """Test get_model_token_limit with unknown provider/model.""" config = {"provider": "unknown", "model": "unknown-model"} + mock_config_repository.update(config) token_limit = get_model_token_limit(config, "default") assert token_limit is None -def test_get_model_token_limit_missing_config(mock_memory): +def test_get_model_token_limit_missing_config(mock_config_repository): """Test get_model_token_limit with missing configuration.""" config = {} + mock_config_repository.update(config) token_limit = get_model_token_limit(config, "default") assert token_limit is None @@ -108,9 +139,9 @@ def test_get_model_token_limit_unexpected_error(): assert token_limit is None -def test_create_agent_anthropic(mock_model, mock_memory): +def test_create_agent_anthropic(mock_model, mock_config_repository): """Test create_agent with Anthropic Claude model.""" - mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"} + mock_config_repository.update({"provider": "anthropic", "model": "claude-2"}) with patch("ra_aid.agent_utils.create_react_agent") as mock_react: mock_react.return_value = "react_agent" @@ -125,9 +156,9 @@ def test_create_agent_anthropic(mock_model, mock_memory): ) -def test_create_agent_openai(mock_model, mock_memory): +def test_create_agent_openai(mock_model, mock_config_repository): """Test create_agent with OpenAI model.""" - mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"} + mock_config_repository.update({"provider": "openai", "model": "gpt-4"}) with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: mock_ciayn.return_value = "ciayn_agent" @@ -142,9 +173,9 @@ def test_create_agent_openai(mock_model, mock_memory): ) -def test_create_agent_no_token_limit(mock_model, mock_memory): +def test_create_agent_no_token_limit(mock_model, mock_config_repository): """Test create_agent when no token limit is found.""" - mock_memory.get.return_value = {"provider": "unknown", "model": "unknown-model"} + mock_config_repository.update({"provider": "unknown", "model": "unknown-model"}) with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: mock_ciayn.return_value = "ciayn_agent" @@ -159,9 +190,9 @@ def test_create_agent_no_token_limit(mock_model, mock_memory): ) -def test_create_agent_missing_config(mock_model, mock_memory): +def test_create_agent_missing_config(mock_model, mock_config_repository): """Test create_agent with missing configuration.""" - mock_memory.get.return_value = {"provider": "openai"} + mock_config_repository.update({"provider": "openai"}) with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: mock_ciayn.return_value = "ciayn_agent" @@ -205,9 +236,9 @@ def test_state_modifier(mock_messages): assert result[-1] == mock_messages[-1] -def test_create_agent_with_checkpointer(mock_model, mock_memory): +def test_create_agent_with_checkpointer(mock_model, mock_config_repository): """Test create_agent with checkpointer argument.""" - mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"} + mock_config_repository.update({"provider": "openai", "model": "gpt-4"}) mock_checkpointer = Mock() with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn: @@ -223,13 +254,13 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory): ) -def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory): +def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository): """Test create_agent sets up token limiting for Claude models when enabled.""" - mock_memory.get.return_value = { + mock_config_repository.update({ "provider": "anthropic", "model": "claude-2", "limit_tokens": True, - } + }) with ( patch("ra_aid.agent_utils.create_react_agent") as mock_react, @@ -246,13 +277,13 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory): assert callable(args[1]["state_modifier"]) -def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory): +def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository): """Test create_agent doesn't set up token limiting for Claude models when disabled.""" - mock_memory.get.return_value = { + mock_config_repository.update({ "provider": "anthropic", "model": "claude-2", "limit_tokens": False, - } + }) with ( patch("ra_aid.agent_utils.create_react_agent") as mock_react, @@ -267,7 +298,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory) mock_react.assert_called_once_with(mock_model, [], version="v2") -def test_get_model_token_limit_research(mock_memory): +def test_get_model_token_limit_research(mock_config_repository): """Test get_model_token_limit with research provider and model.""" config = { "provider": "openai", @@ -275,13 +306,15 @@ def test_get_model_token_limit_research(mock_memory): "research_provider": "anthropic", "research_model": "claude-2", } + mock_config_repository.update(config) + with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.return_value = {"max_input_tokens": 150000} token_limit = get_model_token_limit(config, "research") assert token_limit == 150000 -def test_get_model_token_limit_planner(mock_memory): +def test_get_model_token_limit_planner(mock_config_repository): """Test get_model_token_limit with planner provider and model.""" config = { "provider": "openai", @@ -289,6 +322,8 @@ def test_get_model_token_limit_planner(mock_memory): "planner_provider": "deepseek", "planner_model": "dsm-1", } + mock_config_repository.update(config) + with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.return_value = {"max_input_tokens": 120000} token_limit = get_model_token_limit(config, "planner") diff --git a/tests/ra_aid/test_info_query.py b/tests/ra_aid/test_info_query.py index d54f443..c8b8ef2 100644 --- a/tests/ra_aid/test_info_query.py +++ b/tests/ra_aid/test_info_query.py @@ -1,41 +1,35 @@ """Tests for the is_informational_query and is_stage_requested functions.""" +import pytest + from ra_aid.__main__ import is_informational_query, is_stage_requested -from ra_aid.tools.memory import _global_memory +from ra_aid.database.repositories.config_repository import ConfigRepositoryManager -def test_is_informational_query(): +@pytest.fixture +def config_repo(): + """Fixture for config repository.""" + with ConfigRepositoryManager() as repo: + yield repo + + +def test_is_informational_query(config_repo): """Test that is_informational_query only depends on research_only config setting.""" - # Clear global memory to ensure clean state - _global_memory.clear() - # When research_only is True, should return True - _global_memory["config"] = {"research_only": True} + config_repo.set("research_only", True) assert is_informational_query() is True # When research_only is False, should return False - _global_memory["config"] = {"research_only": False} + config_repo.set("research_only", False) assert is_informational_query() is False # When config is empty, should return False (default) - _global_memory.clear() - _global_memory["config"] = {} - assert is_informational_query() is False - - # When global memory is empty, should return False (default) - _global_memory.clear() + config_repo.update({}) assert is_informational_query() is False def test_is_stage_requested(): """Test that is_stage_requested always returns False now.""" - # Clear global memory to ensure clean state - _global_memory.clear() - # Should always return False regardless of input assert is_stage_requested("implementation") is False - assert is_stage_requested("anything_else") is False - - # Even if we set implementation_requested in global memory - _global_memory["implementation_requested"] = True - assert is_stage_requested("implementation") is False \ No newline at end of file + assert is_stage_requested("anything_else") is False \ No newline at end of file diff --git a/tests/ra_aid/test_main.py b/tests/ra_aid/test_main.py index 2787b8f..e7a1846 100644 --- a/tests/ra_aid/test_main.py +++ b/tests/ra_aid/test_main.py @@ -7,14 +7,50 @@ from ra_aid.__main__ import parse_arguments from ra_aid.config import DEFAULT_RECURSION_LIMIT from ra_aid.tools.memory import _global_memory from ra_aid.database.repositories.work_log_repository import WorkLogEntry +from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository + + +@pytest.fixture(autouse=True) +def mock_config_repository(): + """Mock the ConfigRepository to avoid database operations during tests""" + with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var: + # Setup a mock repository + mock_repo = MagicMock() + + # Create a dictionary to simulate config + config = {} + + # Setup get method to return config values + def get_config(key, default=None): + return config.get(key, default) + mock_repo.get.side_effect = get_config + + # Setup set method to update config values + def set_config(key, value): + config[key] = value + mock_repo.set.side_effect = set_config + + # Setup update method to update multiple config values + def update_config(config_dict): + config.update(config_dict) + mock_repo.update.side_effect = update_config + + # Setup get_all method to return the config dict + def get_all_config(): + return config.copy() + mock_repo.get_all.side_effect = get_all_config + + # Make the mock context var return our mock repo + mock_repo_var.get.return_value = mock_repo + + yield mock_repo @pytest.fixture def mock_dependencies(monkeypatch): """Mock all dependencies needed for main().""" - # Initialize global memory with necessary keys to prevent KeyError + # Initialize global memory _global_memory.clear() - _global_memory["config"] = {} # Mock dependencies that interact with external systems monkeypatch.setattr("ra_aid.__main__.check_dependencies", lambda: None) @@ -26,10 +62,9 @@ def mock_dependencies(monkeypatch): # Mock LLM initialization def mock_config_update(*args, **kwargs): - config = _global_memory.get("config", {}) + config_repo = get_config_repository() if kwargs.get("temperature"): - config["temperature"] = kwargs["temperature"] - _global_memory["config"] = config + config_repo.set("temperature", kwargs["temperature"]) return None monkeypatch.setattr("ra_aid.__main__.initialize_llm", mock_config_update) @@ -107,26 +142,52 @@ def mock_work_log_repository(): yield mock_repo -def test_recursion_limit_in_global_config(mock_dependencies): +def test_recursion_limit_in_global_config(mock_dependencies, mock_config_repository): """Test that recursion limit is correctly set in global config.""" import sys from unittest.mock import patch from ra_aid.__main__ import main - _global_memory.clear() + # Clear the mock repository before each test + mock_config_repository.update.reset_mock() + + # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock + with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository): + # Test default recursion limit + with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]): + main() + # Check that the recursion_limit value was included in the update call + mock_config_repository.update.assert_called() + # Get the call arguments + call_args = mock_config_repository.update.call_args_list + # Find the call that includes recursion_limit + recursion_limit_found = False + for args, _ in call_args: + config_dict = args[0] + if "recursion_limit" in config_dict and config_dict["recursion_limit"] == DEFAULT_RECURSION_LIMIT: + recursion_limit_found = True + break + assert recursion_limit_found, f"recursion_limit not found in update calls: {call_args}" - with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]): - main() - assert _global_memory["config"]["recursion_limit"] == DEFAULT_RECURSION_LIMIT - - _global_memory.clear() - - with patch.object( - sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"] - ): - main() - assert _global_memory["config"]["recursion_limit"] == 50 + # Reset mock to clear call history + mock_config_repository.update.reset_mock() + + # Test custom recursion limit + with patch.object(sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"]): + main() + # Check that the recursion_limit value was included in the update call + mock_config_repository.update.assert_called() + # Get the call arguments + call_args = mock_config_repository.update.call_args_list + # Find the call that includes recursion_limit with value 50 + recursion_limit_found = False + for args, _ in call_args: + config_dict = args[0] + if "recursion_limit" in config_dict and config_dict["recursion_limit"] == 50: + recursion_limit_found = True + break + assert recursion_limit_found, f"recursion_limit=50 not found in update calls: {call_args}" def test_negative_recursion_limit(): @@ -141,70 +202,83 @@ def test_zero_recursion_limit(): parse_arguments(["-m", "test message", "--recursion-limit", "0"]) -def test_config_settings(mock_dependencies): +def test_config_settings(mock_dependencies, mock_config_repository): """Test that various settings are correctly applied in global config.""" import sys from unittest.mock import patch from ra_aid.__main__ import main - - _global_memory.clear() - - with patch.object( - sys, - "argv", - [ - "ra-aid", - "-m", - "test message", - "--cowboy-mode", - "--research-only", - "--provider", - "anthropic", - "--model", - "claude-3-7-sonnet-20250219", - "--expert-provider", - "openai", - "--expert-model", - "gpt-4", - "--temperature", - "0.7", - "--disable-limit-tokens", - ], - ): - main() - config = _global_memory["config"] - assert config["cowboy_mode"] is True - assert config["research_only"] is True - assert config["provider"] == "anthropic" - assert config["model"] == "claude-3-7-sonnet-20250219" - assert config["expert_provider"] == "openai" - assert config["expert_model"] == "gpt-4" - assert config["limit_tokens"] is False + + # Clear the mock repository before each test + mock_config_repository.update.reset_mock() + + # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock + with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository): + with patch.object( + sys, + "argv", + [ + "ra-aid", + "-m", + "test message", + "--cowboy-mode", + "--research-only", + "--provider", + "anthropic", + "--model", + "claude-3-7-sonnet-20250219", + "--expert-provider", + "openai", + "--expert-model", + "gpt-4", + "--temperature", + "0.7", + "--disable-limit-tokens", + ], + ): + main() + # Verify config values are set via the update method + mock_config_repository.update.assert_called() + # Get the call arguments + call_args = mock_config_repository.update.call_args_list + + # Check for config values in the update calls + for args, _ in call_args: + config_dict = args[0] + if "cowboy_mode" in config_dict: + assert config_dict["cowboy_mode"] is True + if "research_only" in config_dict: + assert config_dict["research_only"] is True + if "limit_tokens" in config_dict: + assert config_dict["limit_tokens"] is False + + # Check provider and model settings via set method + mock_config_repository.set.assert_any_call("provider", "anthropic") + mock_config_repository.set.assert_any_call("model", "claude-3-7-sonnet-20250219") + mock_config_repository.set.assert_any_call("expert_provider", "openai") + mock_config_repository.set.assert_any_call("expert_model", "gpt-4") -def test_temperature_validation(mock_dependencies): +def test_temperature_validation(mock_dependencies, mock_config_repository): """Test that temperature argument is correctly passed to initialize_llm.""" import sys from unittest.mock import patch, ANY from ra_aid.__main__ import main - # Reset global memory for clean test - _global_memory.clear() - _global_memory["config"] = {} - - # Test valid temperature (0.7) - with patch("ra_aid.__main__.initialize_llm", return_value=None) as mock_init_llm: - # Also patch any calls that would actually use the mocked initialize_llm function - with patch("ra_aid.__main__.run_research_agent", return_value=None): - with patch("ra_aid.__main__.run_planning_agent", return_value=None): - with patch.object( - sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"] - ): - main() - # Check if temperature was stored in config correctly - assert _global_memory["config"]["temperature"] == 0.7 + # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock + with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository): + # Test valid temperature (0.7) + with patch("ra_aid.__main__.initialize_llm", return_value=None) as mock_init_llm: + # Also patch any calls that would actually use the mocked initialize_llm function + with patch("ra_aid.__main__.run_research_agent", return_value=None): + with patch("ra_aid.__main__.run_planning_agent", return_value=None): + with patch.object( + sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"] + ): + main() + # Verify that the temperature was set in the config repository + mock_config_repository.set.assert_any_call("temperature", 0.7) # Test invalid temperature (2.1) with pytest.raises(SystemExit): @@ -230,61 +304,67 @@ def test_missing_message(): assert args.message == "test" -def test_research_model_provider_args(mock_dependencies): +def test_research_model_provider_args(mock_dependencies, mock_config_repository): """Test that research-specific model/provider args are correctly stored in config.""" import sys from unittest.mock import patch from ra_aid.__main__ import main - _global_memory.clear() - - with patch.object( - sys, - "argv", - [ - "ra-aid", - "-m", - "test message", - "--research-provider", - "anthropic", - "--research-model", - "claude-3-haiku-20240307", - "--planner-provider", - "openai", - "--planner-model", - "gpt-4", - ], - ): - main() - config = _global_memory["config"] - assert config["research_provider"] == "anthropic" - assert config["research_model"] == "claude-3-haiku-20240307" - assert config["planner_provider"] == "openai" - assert config["planner_model"] == "gpt-4" + # Reset mocks + mock_config_repository.set.reset_mock() + + # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock + with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository): + with patch.object( + sys, + "argv", + [ + "ra-aid", + "-m", + "test message", + "--research-provider", + "anthropic", + "--research-model", + "claude-3-haiku-20240307", + "--planner-provider", + "openai", + "--planner-model", + "gpt-4", + ], + ): + main() + # Verify the mock repo's set method was called with the expected values + mock_config_repository.set.assert_any_call("research_provider", "anthropic") + mock_config_repository.set.assert_any_call("research_model", "claude-3-haiku-20240307") + mock_config_repository.set.assert_any_call("planner_provider", "openai") + mock_config_repository.set.assert_any_call("planner_model", "gpt-4") -def test_planner_model_provider_args(mock_dependencies): +def test_planner_model_provider_args(mock_dependencies, mock_config_repository): """Test that planner provider/model args fall back to main config when not specified.""" import sys from unittest.mock import patch from ra_aid.__main__ import main - _global_memory.clear() - - with patch.object( - sys, - "argv", - ["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"], - ): - main() - config = _global_memory["config"] - assert config["planner_provider"] == "openai" - assert config["planner_model"] == "gpt-4" + # Reset mocks + mock_config_repository.set.reset_mock() + + # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock + with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository): + with patch.object( + sys, + "argv", + ["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"], + ): + main() + # Verify the mock repo's set method was called with the expected values + mock_config_repository.set.assert_any_call("planner_provider", "openai") + mock_config_repository.set.assert_any_call("planner_model", "gpt-4") -def test_use_aider_flag(mock_dependencies): +def test_use_aider_flag(mock_dependencies, mock_config_repository): """Test that use-aider flag is correctly stored in config.""" import sys from unittest.mock import patch @@ -292,44 +372,68 @@ def test_use_aider_flag(mock_dependencies): from ra_aid.__main__ import main from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools - _global_memory.clear() - + # Reset mocks + mock_config_repository.update.reset_mock() + # Reset to default state set_modification_tools(False) - # Check default behavior (use_aider=False) - with patch.object( - sys, - "argv", - ["ra-aid", "-m", "test message"], - ): - main() - config = _global_memory["config"] - assert config.get("use_aider") is False + # For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock + with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository): + # Check default behavior (use_aider=False) + with patch.object( + sys, + "argv", + ["ra-aid", "-m", "test message"], + ): + main() + # Verify use_aider is set to False in the update call + mock_config_repository.update.assert_called() + # Get the call arguments + call_args = mock_config_repository.update.call_args_list + # Find the call that includes use_aider + use_aider_found = False + for args, _ in call_args: + config_dict = args[0] + if "use_aider" in config_dict and config_dict["use_aider"] is False: + use_aider_found = True + break + assert use_aider_found, f"use_aider=False not found in update calls: {call_args}" - # Check that file tools are enabled by default - tool_names = [tool.name for tool in MODIFICATION_TOOLS] - assert "file_str_replace" in tool_names - assert "put_complete_file_contents" in tool_names - assert "run_programming_task" not in tool_names + # Check that file tools are enabled by default + tool_names = [tool.name for tool in MODIFICATION_TOOLS] + assert "file_str_replace" in tool_names + assert "put_complete_file_contents" in tool_names + assert "run_programming_task" not in tool_names - _global_memory.clear() + # Reset mocks + mock_config_repository.update.reset_mock() - # Check with --use-aider flag - with patch.object( - sys, - "argv", - ["ra-aid", "-m", "test message", "--use-aider"], - ): - main() - config = _global_memory["config"] - assert config.get("use_aider") is True + # Check with --use-aider flag + with patch.object( + sys, + "argv", + ["ra-aid", "-m", "test message", "--use-aider"], + ): + main() + # Verify use_aider is set to True in the update call + mock_config_repository.update.assert_called() + # Get the call arguments + call_args = mock_config_repository.update.call_args_list + # Find the call that includes use_aider + use_aider_found = False + for args, _ in call_args: + config_dict = args[0] + if "use_aider" in config_dict and config_dict["use_aider"] is True: + use_aider_found = True + break + assert use_aider_found, f"use_aider=True not found in update calls: {call_args}" - # Check that run_programming_task is enabled - tool_names = [tool.name for tool in MODIFICATION_TOOLS] - assert "file_str_replace" not in tool_names - assert "put_complete_file_contents" not in tool_names - assert "run_programming_task" in tool_names + # Check that run_programming_task is enabled + tool_names = [tool.name for tool in MODIFICATION_TOOLS] + assert "file_str_replace" not in tool_names + assert "put_complete_file_contents" not in tool_names + assert "run_programming_task" in tool_names # Reset to default state for other tests set_modification_tools(False) \ No newline at end of file diff --git a/tests/ra_aid/test_programmer.py b/tests/ra_aid/test_programmer.py index cf24f00..82e9a36 100644 --- a/tests/ra_aid/test_programmer.py +++ b/tests/ra_aid/test_programmer.py @@ -1,3 +1,4 @@ +import os import pytest from unittest.mock import patch, MagicMock @@ -7,6 +8,36 @@ from ra_aid.tools.programmer import ( run_programming_task, ) from ra_aid.database.repositories.related_files_repository import get_related_files_repository +from ra_aid.database.repositories.config_repository import get_config_repository + +@pytest.fixture(autouse=True) +def mock_config_repository(): + """Mock the ConfigRepository to avoid database operations during tests""" + with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var: + # Setup a mock repository + mock_repo = MagicMock() + + # Create a dictionary to simulate config + config = { + "recursion_limit": 2, + "provider": "anthropic", + "model": "claude-3-5-sonnet-20241022", + "temperature": 0.01, + "aider_config": "/path/to/config.yml" + } + + # Setup get_all method to return the config dict + mock_repo.get_all.return_value = config + + # Setup get method to return config values + def get_config(key, default=None): + return config.get(key, default) + mock_repo.get.side_effect = get_config + + # Make the mock context var return our mock repo + mock_repo_var.get.return_value = mock_repo + + yield mock_repo @pytest.fixture(autouse=True) def mock_related_files_repository(): @@ -125,13 +156,9 @@ def test_parse_aider_flags(input_flags, expected, description): assert result == expected, f"Failed test case: {description}" -def test_aider_config_flag(mocker, mock_related_files_repository): +def test_aider_config_flag(mocker, mock_config_repository, mock_related_files_repository): """Test that aider config flag is properly included in the command when specified.""" - # Mock config in global memory but not related files (using repository now) - mock_memory = { - "config": {"aider_config": "/path/to/config.yml"}, - } - mocker.patch("ra_aid.tools.programmer._global_memory", mock_memory) + # Config is mocked by mock_config_repository fixture # Mock the run_interactive_command to capture the command that would be run mock_run = mocker.patch( @@ -146,15 +173,14 @@ def test_aider_config_flag(mocker, mock_related_files_repository): assert args[config_index + 1] == "/path/to/config.yml" -def test_path_normalization_and_deduplication(mocker, tmp_path, mock_related_files_repository): +def test_path_normalization_and_deduplication(mocker, tmp_path, mock_config_repository, mock_related_files_repository): """Test path normalization and deduplication in run_programming_task.""" # Create a temporary test file test_file = tmp_path / "test.py" test_file.write_text("") new_file = tmp_path / "new.py" - # Mock dependencies - only need to mock config part of global memory now - mocker.patch("ra_aid.tools.programmer._global_memory", {"config": {}}) + # Config is mocked by mock_config_repository fixture mocker.patch( "ra_aid.tools.programmer.get_aider_executable", return_value="/path/to/aider" ) diff --git a/tests/ra_aid/tools/test_agent.py b/tests/ra_aid/tools/test_agent.py index fdb36df..e371dfc 100644 --- a/tests/ra_aid/tools/test_agent.py +++ b/tests/ra_aid/tools/test_agent.py @@ -13,6 +13,7 @@ from ra_aid.tools.agent import ( from ra_aid.tools.memory import _global_memory from ra_aid.database.repositories.related_files_repository import get_related_files_repository from ra_aid.database.repositories.work_log_repository import get_work_log_repository, WorkLogEntry +from ra_aid.database.repositories.config_repository import get_config_repository @pytest.fixture @@ -43,6 +44,34 @@ def mock_related_files_repository(): yield mock_repo +@pytest.fixture(autouse=True) +def mock_config_repository(): + """Mock the ConfigRepository to avoid database operations during tests""" + with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var: + # Setup a mock repository + mock_repo = MagicMock() + + # Create a dictionary to simulate config + config = { + "recursion_limit": 2, + "provider": "anthropic", + "model": "claude-3-5-sonnet-20241022", + "temperature": 0.01 + } + + # Setup get_all method to return the config dict + mock_repo.get_all.return_value = config + + # Setup get method to return config values + def get_config(key, default=None): + return config.get(key, default) + mock_repo.get.side_effect = get_config + + # Make the mock context var return our mock repo + mock_repo_var.get.return_value = mock_repo + + yield mock_repo + @pytest.fixture(autouse=True) def mock_work_log_repository(): """Mock the WorkLogRepository to avoid database operations during tests""" diff --git a/tests/ra_aid/tools/test_shell.py b/tests/ra_aid/tools/test_shell.py index 3636ae6..6d2e410 100644 --- a/tests/ra_aid/tools/test_shell.py +++ b/tests/ra_aid/tools/test_shell.py @@ -1,8 +1,8 @@ -from unittest.mock import patch +from unittest.mock import patch, MagicMock import pytest -from ra_aid.tools.memory import _global_memory +from ra_aid.database.repositories.config_repository import ConfigRepositoryManager from ra_aid.tools.shell import run_shell_command @@ -25,9 +25,38 @@ def mock_run_interactive(): yield mock -def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive): +@pytest.fixture(autouse=True) +def mock_config_repository(): + """Mock the ConfigRepository to avoid database operations during tests""" + with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var: + # Setup a mock repository + mock_repo = MagicMock() + + # Create a dictionary to simulate config + config = { + "cowboy_mode": False + } + + # Setup get method to return config values + def get_config(key, default=None): + return config.get(key, default) + mock_repo.get.side_effect = get_config + + # Setup set method to update config values + def set_config(key, value): + config[key] = value + mock_repo.set.side_effect = set_config + + # Make the mock context var return our mock repo + mock_repo_var.get.return_value = mock_repo + + yield mock_repo + + +def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive, mock_config_repository): """Test shell command execution in cowboy mode (no approval)""" - _global_memory["config"] = {"cowboy_mode": True} + # Set cowboy mode to True using the repository + mock_config_repository.set("cowboy_mode", True) result = run_shell_command.invoke({"command": "echo test"}) @@ -37,9 +66,10 @@ def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interacti mock_prompt.ask.assert_not_called() -def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive): +def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive, mock_config_repository): """Test that cowboy mode displays a properly formatted cowboy message with correct spacing""" - _global_memory["config"] = {"cowboy_mode": True} + # Set cowboy mode to True using the repository + mock_config_repository.set("cowboy_mode", True) with patch("ra_aid.tools.shell.get_cowboy_message") as mock_get_message: mock_get_message.return_value = "🤠 Test cowboy message!" @@ -53,10 +83,11 @@ def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_intera def test_shell_command_interactive_approved( - mock_console, mock_prompt, mock_run_interactive + mock_console, mock_prompt, mock_run_interactive, mock_config_repository ): """Test shell command execution with interactive approval""" - _global_memory["config"] = {"cowboy_mode": False} + # Set cowboy mode to False using the repository + mock_config_repository.set("cowboy_mode", False) mock_prompt.ask.return_value = "y" result = run_shell_command.invoke({"command": "echo test"}) @@ -74,10 +105,11 @@ def test_shell_command_interactive_approved( def test_shell_command_interactive_rejected( - mock_console, mock_prompt, mock_run_interactive + mock_console, mock_prompt, mock_run_interactive, mock_config_repository ): """Test shell command rejection in interactive mode""" - _global_memory["config"] = {"cowboy_mode": False} + # Set cowboy mode to False using the repository + mock_config_repository.set("cowboy_mode", False) mock_prompt.ask.return_value = "n" result = run_shell_command.invoke({"command": "echo test"}) @@ -95,13 +127,14 @@ def test_shell_command_interactive_rejected( mock_run_interactive.assert_not_called() -def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive): +def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive, mock_config_repository): """Test handling of shell command execution errors""" - _global_memory["config"] = {"cowboy_mode": True} + # Set cowboy mode to True using the repository + mock_config_repository.set("cowboy_mode", True) mock_run_interactive.side_effect = Exception("Command failed") result = run_shell_command.invoke({"command": "invalid command"}) assert result["success"] is False assert result["return_code"] == 1 - assert "Command failed" in result["output"] + assert "Command failed" in result["output"] \ No newline at end of file