config repository
This commit is contained in:
parent
3e68dd3fa6
commit
5bd8c76a22
|
|
@ -58,6 +58,10 @@ from ra_aid.database.repositories.related_files_repository import (
|
|||
from ra_aid.database.repositories.work_log_repository import (
|
||||
WorkLogRepositoryManager
|
||||
)
|
||||
from ra_aid.database.repositories.config_repository import (
|
||||
ConfigRepositoryManager,
|
||||
get_config_repository
|
||||
)
|
||||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
from ra_aid.console.output import cpm
|
||||
|
|
@ -77,7 +81,7 @@ from ra_aid.prompts.chat_prompts import CHAT_PROMPT
|
|||
from ra_aid.prompts.web_research_prompts import WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||
from ra_aid.tool_configs import get_chat_tools, set_modification_tools
|
||||
from ra_aid.tools.human import ask_human
|
||||
from ra_aid.tools.memory import _global_memory, get_memory_value
|
||||
from ra_aid.tools.memory import get_memory_value
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -338,7 +342,7 @@ implementation_memory = MemorySaver()
|
|||
|
||||
def is_informational_query() -> bool:
|
||||
"""Determine if the current query is informational based on config settings."""
|
||||
return _global_memory.get("config", {}).get("research_only", False)
|
||||
return get_config_repository().get("research_only", False)
|
||||
|
||||
|
||||
def is_stage_requested(stage: str) -> bool:
|
||||
|
|
@ -404,13 +408,17 @@ def main():
|
|||
except Exception as e:
|
||||
logger.error(f"Database migration error: {str(e)}")
|
||||
|
||||
# Initialize empty config dictionary to be populated later
|
||||
config = {}
|
||||
|
||||
# Initialize repositories with database connection
|
||||
with KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||
RelatedFilesRepositoryManager() as related_files_repo, \
|
||||
WorkLogRepositoryManager() as work_log_repo:
|
||||
WorkLogRepositoryManager() as work_log_repo, \
|
||||
ConfigRepositoryManager(config) as config_repo:
|
||||
# This initializes all repositories and makes them available via their respective get methods
|
||||
logger.debug("Initialized KeyFactRepository")
|
||||
logger.debug("Initialized KeySnippetRepository")
|
||||
|
|
@ -418,6 +426,7 @@ def main():
|
|||
logger.debug("Initialized ResearchNoteRepository")
|
||||
logger.debug("Initialized RelatedFilesRepository")
|
||||
logger.debug("Initialized WorkLogRepository")
|
||||
logger.debug("Initialized ConfigRepository")
|
||||
|
||||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
|
|
@ -520,13 +529,13 @@ def main():
|
|||
"limit_tokens": args.disable_limit_tokens,
|
||||
}
|
||||
|
||||
# Store config in global memory
|
||||
_global_memory["config"] = config
|
||||
_global_memory["config"]["provider"] = args.provider
|
||||
_global_memory["config"]["model"] = args.model
|
||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||
_global_memory["config"]["expert_model"] = args.expert_model
|
||||
_global_memory["config"]["temperature"] = args.temperature
|
||||
# Store config in repository
|
||||
config_repo.update(config)
|
||||
config_repo.set("provider", args.provider)
|
||||
config_repo.set("model", args.model)
|
||||
config_repo.set("expert_provider", args.expert_provider)
|
||||
config_repo.set("expert_model", args.expert_model)
|
||||
config_repo.set("temperature", args.temperature)
|
||||
|
||||
# Set modification tools based on use_aider flag
|
||||
set_modification_tools(args.use_aider)
|
||||
|
|
@ -594,33 +603,27 @@ def main():
|
|||
"test_cmd_timeout": args.test_cmd_timeout,
|
||||
}
|
||||
|
||||
# Store config in global memory for access by is_informational_query
|
||||
_global_memory["config"] = config
|
||||
# Store config in repository
|
||||
config_repo.update(config)
|
||||
|
||||
# Store base provider/model configuration
|
||||
_global_memory["config"]["provider"] = args.provider
|
||||
_global_memory["config"]["model"] = args.model
|
||||
config_repo.set("provider", args.provider)
|
||||
config_repo.set("model", args.model)
|
||||
|
||||
# Store expert provider/model (no fallback)
|
||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||
_global_memory["config"]["expert_model"] = args.expert_model
|
||||
config_repo.set("expert_provider", args.expert_provider)
|
||||
config_repo.set("expert_model", args.expert_model)
|
||||
|
||||
# Store planner config with fallback to base values
|
||||
_global_memory["config"]["planner_provider"] = (
|
||||
args.planner_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
||||
config_repo.set("planner_provider", args.planner_provider or args.provider)
|
||||
config_repo.set("planner_model", args.planner_model or args.model)
|
||||
|
||||
# Store research config with fallback to base values
|
||||
_global_memory["config"]["research_provider"] = (
|
||||
args.research_provider or args.provider
|
||||
)
|
||||
_global_memory["config"]["research_model"] = (
|
||||
args.research_model or args.model
|
||||
)
|
||||
config_repo.set("research_provider", args.research_provider or args.provider)
|
||||
config_repo.set("research_model", args.research_model or args.model)
|
||||
|
||||
# Store temperature in global config
|
||||
_global_memory["config"]["temperature"] = args.temperature
|
||||
# Store temperature in config
|
||||
config_repo.set("temperature", args.temperature)
|
||||
|
||||
# Set modification tools based on use_aider flag
|
||||
set_modification_tools(args.use_aider)
|
||||
|
|
|
|||
|
|
@ -94,11 +94,11 @@ from ra_aid.model_formatters import format_key_facts_dict
|
|||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
||||
from ra_aid.tools.memory import (
|
||||
_global_memory,
|
||||
get_memory_value,
|
||||
get_related_files,
|
||||
log_work_event,
|
||||
)
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -302,7 +302,7 @@ def create_agent(
|
|||
config['limit_tokens'] = False.
|
||||
"""
|
||||
try:
|
||||
config = _global_memory.get("config", {})
|
||||
config = get_config_repository().get_all()
|
||||
max_input_tokens = (
|
||||
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
|
||||
)
|
||||
|
|
@ -319,7 +319,7 @@ def create_agent(
|
|||
except Exception as e:
|
||||
# Default to REACT agent if provider/model detection fails
|
||||
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
|
||||
config = _global_memory.get("config", {})
|
||||
config = get_config_repository().get_all()
|
||||
max_input_tokens = get_model_token_limit(config, agent_type)
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
|
||||
return create_react_agent(model, tools, **agent_kwargs)
|
||||
|
|
@ -443,7 +443,7 @@ def run_research_agent(
|
|||
new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "",
|
||||
)
|
||||
|
||||
config = _global_memory.get("config", {}) if not config else config
|
||||
config = get_config_repository().get_all() if not config else config
|
||||
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
|
|
@ -575,7 +575,7 @@ def run_web_research_agent(
|
|||
related_files=related_files,
|
||||
)
|
||||
|
||||
config = _global_memory.get("config", {}) if not config else config
|
||||
config = get_config_repository().get_all() if not config else config
|
||||
|
||||
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||
run_config = {
|
||||
|
|
@ -709,7 +709,7 @@ def run_planning_agent(
|
|||
),
|
||||
)
|
||||
|
||||
config = _global_memory.get("config", {}) if not config else config
|
||||
config = get_config_repository().get_all() if not config else config
|
||||
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
|
|
@ -824,7 +824,7 @@ def run_task_implementation_agent(
|
|||
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
||||
human_section=(
|
||||
HUMAN_PROMPT_SECTION_IMPLEMENTATION
|
||||
if _global_memory.get("config", {}).get("hil", False)
|
||||
if get_config_repository().get("hil", False)
|
||||
else ""
|
||||
),
|
||||
web_research_section=(
|
||||
|
|
@ -834,7 +834,7 @@ def run_task_implementation_agent(
|
|||
),
|
||||
)
|
||||
|
||||
config = _global_memory.get("config", {}) if not config else config
|
||||
config = get_config_repository().get_all() if not config else config
|
||||
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@ logger = logging.getLogger(__name__)
|
|||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
||||
from ra_aid.tools.memory import log_work_event, _global_memory
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
|
||||
|
||||
console = Console()
|
||||
|
|
@ -149,7 +150,7 @@ def run_key_facts_gc_agent() -> None:
|
|||
formatted_facts = "\n".join([f"Fact #{k}: {v}" for k, v in facts_dict.items()])
|
||||
|
||||
# Retrieve configuration
|
||||
llm_config = _global_memory.get("config", {})
|
||||
llm_config = get_config_repository().get_all()
|
||||
|
||||
# Initialize the LLM model
|
||||
model = initialize_llm(
|
||||
|
|
|
|||
|
|
@ -16,9 +16,10 @@ from rich.panel import Panel
|
|||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
||||
from ra_aid.tools.memory import log_work_event, _global_memory
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
|
||||
|
||||
console = Console()
|
||||
|
|
@ -153,7 +154,7 @@ def run_key_snippets_gc_agent() -> None:
|
|||
])
|
||||
|
||||
# Retrieve configuration
|
||||
llm_config = _global_memory.get("config", {})
|
||||
llm_config = get_config_repository().get_all()
|
||||
|
||||
# Initialize the LLM model
|
||||
model = initialize_llm(
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@ logger = logging.getLogger(__name__)
|
|||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
||||
from ra_aid.tools.memory import log_work_event, _global_memory
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
|
||||
|
||||
console = Console()
|
||||
|
|
@ -154,7 +155,7 @@ def run_research_notes_gc_agent(threshold: int = 30) -> None:
|
|||
formatted_notes = "\n".join([f"Note #{k}: {v}" for k, v in notes_dict.items()])
|
||||
|
||||
# Retrieve configuration
|
||||
llm_config = _global_memory.get("config", {})
|
||||
llm_config = get_config_repository().get_all()
|
||||
|
||||
# Initialize the LLM model
|
||||
model = initialize_llm(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,165 @@
|
|||
"""Repository for managing configuration values."""
|
||||
|
||||
import contextvars
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
# Create contextvar to hold the ConfigRepository instance
|
||||
config_repo_var = contextvars.ContextVar("config_repo", default=None)
|
||||
|
||||
|
||||
class ConfigRepository:
|
||||
"""
|
||||
Repository for managing configuration values in memory.
|
||||
|
||||
This class provides methods to get, set, update, and retrieve all configuration values.
|
||||
It does not require database models and operates entirely in memory.
|
||||
"""
|
||||
|
||||
def __init__(self, initial_config: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
Initialize the ConfigRepository.
|
||||
|
||||
Args:
|
||||
initial_config: Optional dictionary of initial configuration values
|
||||
"""
|
||||
self._config: Dict[str, Any] = {}
|
||||
|
||||
# Initialize with default values from config.py
|
||||
from ra_aid.config import (
|
||||
DEFAULT_RECURSION_LIMIT,
|
||||
DEFAULT_MAX_TEST_CMD_RETRIES,
|
||||
DEFAULT_MAX_TOOL_FAILURES,
|
||||
FALLBACK_TOOL_MODEL_LIMIT,
|
||||
RETRY_FALLBACK_COUNT,
|
||||
DEFAULT_TEST_CMD_TIMEOUT,
|
||||
VALID_PROVIDERS,
|
||||
)
|
||||
|
||||
self._config = {
|
||||
"recursion_limit": DEFAULT_RECURSION_LIMIT,
|
||||
"max_test_cmd_retries": DEFAULT_MAX_TEST_CMD_RETRIES,
|
||||
"max_tool_failures": DEFAULT_MAX_TOOL_FAILURES,
|
||||
"fallback_tool_model_limit": FALLBACK_TOOL_MODEL_LIMIT,
|
||||
"retry_fallback_count": RETRY_FALLBACK_COUNT,
|
||||
"test_cmd_timeout": DEFAULT_TEST_CMD_TIMEOUT,
|
||||
"valid_providers": VALID_PROVIDERS,
|
||||
}
|
||||
|
||||
# Update with any provided initial configuration
|
||||
if initial_config:
|
||||
self._config.update(initial_config)
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get a configuration value by key.
|
||||
|
||||
Args:
|
||||
key: Configuration key to retrieve
|
||||
default: Default value to return if key is not found
|
||||
|
||||
Returns:
|
||||
The configuration value or default if not found
|
||||
"""
|
||||
return self._config.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""
|
||||
Set a configuration value by key.
|
||||
|
||||
Args:
|
||||
key: Configuration key to set
|
||||
value: Value to set for the key
|
||||
"""
|
||||
self._config[key] = value
|
||||
|
||||
def update(self, config_dict: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Update multiple configuration values at once.
|
||||
|
||||
Args:
|
||||
config_dict: Dictionary of configuration key-value pairs to update
|
||||
"""
|
||||
self._config.update(config_dict)
|
||||
|
||||
def get_all(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all configuration values.
|
||||
|
||||
Returns:
|
||||
Dictionary containing all configuration values
|
||||
"""
|
||||
return self._config.copy()
|
||||
|
||||
|
||||
class ConfigRepositoryManager:
|
||||
"""
|
||||
Context manager for ConfigRepository.
|
||||
|
||||
This class provides a context manager interface for ConfigRepository,
|
||||
using the contextvars approach for thread safety.
|
||||
|
||||
Example:
|
||||
with ConfigRepositoryManager() as repo:
|
||||
# Use the repository
|
||||
value = repo.get("key")
|
||||
repo.set("key", new_value)
|
||||
"""
|
||||
|
||||
def __init__(self, initial_config: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
Initialize the ConfigRepositoryManager.
|
||||
|
||||
Args:
|
||||
initial_config: Optional dictionary of initial configuration values
|
||||
"""
|
||||
self.initial_config = initial_config
|
||||
|
||||
def __enter__(self) -> 'ConfigRepository':
|
||||
"""
|
||||
Initialize the ConfigRepository and return it.
|
||||
|
||||
Returns:
|
||||
ConfigRepository: The initialized repository
|
||||
"""
|
||||
repo = ConfigRepository(self.initial_config)
|
||||
config_repo_var.set(repo)
|
||||
return repo
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type],
|
||||
exc_val: Optional[Exception],
|
||||
exc_tb: Optional[object],
|
||||
) -> None:
|
||||
"""
|
||||
Reset the repository when exiting the context.
|
||||
|
||||
Args:
|
||||
exc_type: The exception type if an exception was raised
|
||||
exc_val: The exception value if an exception was raised
|
||||
exc_tb: The traceback if an exception was raised
|
||||
"""
|
||||
# Reset the contextvar to None
|
||||
config_repo_var.set(None)
|
||||
|
||||
# Don't suppress exceptions
|
||||
return False
|
||||
|
||||
|
||||
def get_config_repository() -> ConfigRepository:
|
||||
"""
|
||||
Get the current ConfigRepository instance.
|
||||
|
||||
Returns:
|
||||
ConfigRepository: The current repository instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no repository is set in the current context
|
||||
"""
|
||||
repo = config_repo_var.get()
|
||||
if repo is None:
|
||||
raise RuntimeError(
|
||||
"ConfigRepository not initialized in current context. "
|
||||
"Make sure to use ConfigRepositoryManager."
|
||||
)
|
||||
return repo
|
||||
|
|
@ -27,6 +27,7 @@ from ra_aid.tools.agent import (
|
|||
request_web_research,
|
||||
)
|
||||
from ra_aid.tools.memory import plan_implementation_completed
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
|
||||
|
||||
def set_modification_tools(use_aider=False):
|
||||
|
|
@ -98,13 +99,11 @@ def get_all_tools() -> list[BaseTool]:
|
|||
|
||||
|
||||
# Define constant tool groups
|
||||
# Get config from global memory for use_aider value
|
||||
# Get config from repository for use_aider value
|
||||
_config = {}
|
||||
try:
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
_config = _global_memory.get("config", {})
|
||||
except ImportError:
|
||||
_config = get_config_repository().get_all()
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
READ_ONLY_TOOLS = get_read_only_tools(use_aider=_config.get("use_aider", False))
|
||||
|
|
@ -139,10 +138,8 @@ def get_research_tools(
|
|||
# Get config for use_aider value
|
||||
use_aider = False
|
||||
try:
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
||||
except ImportError:
|
||||
use_aider = get_config_repository().get("use_aider", False)
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
# Start with read-only tools
|
||||
|
|
@ -180,10 +177,8 @@ def get_planning_tools(
|
|||
# Get config for use_aider value
|
||||
use_aider = False
|
||||
try:
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
||||
except ImportError:
|
||||
use_aider = get_config_repository().get("use_aider", False)
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
# Start with read-only tools
|
||||
|
|
@ -219,10 +214,8 @@ def get_implementation_tools(
|
|||
# Get config for use_aider value
|
||||
use_aider = False
|
||||
try:
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
||||
except ImportError:
|
||||
use_aider = get_config_repository().get("use_aider", False)
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
# Start with read-only tools
|
||||
|
|
|
|||
|
|
@ -18,13 +18,13 @@ from ra_aid.console.formatting import print_error
|
|||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
|
||||
from ..console import print_task_header
|
||||
from ..llm import initialize_llm
|
||||
|
|
@ -52,7 +52,7 @@ def request_research(query: str) -> ResearchResult:
|
|||
query: The research question or project description
|
||||
"""
|
||||
# Initialize model from config
|
||||
config = _global_memory.get("config", {})
|
||||
config = get_config_repository().get_all()
|
||||
model = initialize_llm(
|
||||
config.get("provider", "anthropic"),
|
||||
config.get("model", "claude-3-7-sonnet-20250219"),
|
||||
|
|
@ -165,7 +165,7 @@ def request_web_research(query: str) -> ResearchResult:
|
|||
query: The research question or project description
|
||||
"""
|
||||
# Initialize model from config
|
||||
config = _global_memory.get("config", {})
|
||||
config = get_config_repository().get_all()
|
||||
model = initialize_llm(
|
||||
config.get("provider", "anthropic"),
|
||||
config.get("model", "claude-3-7-sonnet-20250219"),
|
||||
|
|
@ -246,7 +246,7 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
|||
query: The research question or project description
|
||||
"""
|
||||
# Initialize model from config
|
||||
config = _global_memory.get("config", {})
|
||||
config = get_config_repository().get_all()
|
||||
model = initialize_llm(
|
||||
config.get("provider", "anthropic"),
|
||||
config.get("model", "claude-3-7-sonnet-20250219"),
|
||||
|
|
@ -335,7 +335,7 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
task_spec: REQUIRED The full task specification (markdown format, typically one part of the overall plan)
|
||||
"""
|
||||
# Initialize model from config
|
||||
config = _global_memory.get("config", {})
|
||||
config = get_config_repository().get_all()
|
||||
model = initialize_llm(
|
||||
config.get("provider", "anthropic"),
|
||||
config.get("model", "claude-3-5-sonnet-20241022"),
|
||||
|
|
@ -474,7 +474,7 @@ def request_implementation(task_spec: str) -> str:
|
|||
task_spec: The task specification to plan implementation for
|
||||
"""
|
||||
# Initialize model from config
|
||||
config = _global_memory.get("config", {})
|
||||
config = get_config_repository().get_all()
|
||||
model = initialize_llm(
|
||||
config.get("provider", "anthropic"),
|
||||
config.get("model", "claude-3-5-sonnet-20241022"),
|
||||
|
|
|
|||
|
|
@ -13,11 +13,12 @@ from ..database.repositories.key_fact_repository import get_key_fact_repository
|
|||
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ..database.repositories.related_files_repository import get_related_files_repository
|
||||
from ..database.repositories.research_note_repository import get_research_note_repository
|
||||
from ..database.repositories.config_repository import get_config_repository
|
||||
from ..llm import initialize_expert_llm
|
||||
from ..model_formatters import format_key_facts_dict
|
||||
from ..model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
from ..model_formatters.research_notes_formatter import format_research_notes_dict
|
||||
from .memory import _global_memory, get_memory_value
|
||||
from .memory import get_memory_value
|
||||
|
||||
console = Console()
|
||||
_model = None
|
||||
|
|
@ -27,9 +28,9 @@ def get_model():
|
|||
global _model
|
||||
try:
|
||||
if _model is None:
|
||||
config = _global_memory["config"]
|
||||
provider = config.get("expert_provider") or config.get("provider")
|
||||
model = config.get("expert_model") or config.get("model")
|
||||
config_repo = get_config_repository()
|
||||
provider = config_repo.get("expert_provider") or config_repo.get("provider")
|
||||
model = config_repo.get("expert_model") or config_repo.get("model")
|
||||
_model = initialize_expert_llm(provider, model)
|
||||
except Exception as e:
|
||||
_model = None
|
||||
|
|
|
|||
|
|
@ -58,10 +58,10 @@ def ask_human(question: str) -> str:
|
|||
# Record human response in database
|
||||
try:
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
|
||||
# Determine the source based on context
|
||||
config = _global_memory.get("config", {})
|
||||
config = get_config_repository().get_all()
|
||||
# If chat_mode is enabled, use 'chat', otherwise determine if hil mode is active
|
||||
if config.get("chat_mode", False):
|
||||
source = "chat"
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ from ra_aid.logging_config import get_logger
|
|||
from ra_aid.models_params import DEFAULT_BASE_LATENCY, models_params
|
||||
from ra_aid.proc.interactive import run_interactive_command
|
||||
from ra_aid.text.processing import truncate_output
|
||||
from ra_aid.tools.memory import _global_memory, log_work_event
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||
|
||||
console = Console()
|
||||
|
|
@ -107,8 +108,9 @@ def run_programming_task(
|
|||
)
|
||||
|
||||
# Add config file if specified
|
||||
if "config" in _global_memory and _global_memory["config"].get("aider_config"):
|
||||
command.extend(["--config", _global_memory["config"]["aider_config"]])
|
||||
config = get_config_repository().get_all()
|
||||
if config.get("aider_config"):
|
||||
command.extend(["--config", config["aider_config"]])
|
||||
|
||||
# if environment variable AIDER_FLAGS exists then parse
|
||||
if "AIDER_FLAGS" in os.environ:
|
||||
|
|
@ -147,8 +149,9 @@ def run_programming_task(
|
|||
# Run the command interactively
|
||||
print()
|
||||
# Get provider/model specific latency coefficient
|
||||
provider = _global_memory.get("config", {}).get("provider", "")
|
||||
model = _global_memory.get("config", {}).get("model", "")
|
||||
config = get_config_repository().get_all()
|
||||
provider = config.get("provider", "")
|
||||
model = config.get("model", "")
|
||||
latency = (
|
||||
models_params.get(provider, {})
|
||||
.get(model, {})
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from ra_aid.console.cowboy_messages import get_cowboy_message
|
|||
from ra_aid.proc.interactive import run_interactive_command
|
||||
from ra_aid.text.processing import truncate_output
|
||||
from ra_aid.tools.memory import _global_memory, log_work_event
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -46,7 +47,7 @@ def run_shell_command(
|
|||
4. Add flags e.g. git --no-pager in order to reduce interaction required by the human.
|
||||
"""
|
||||
# Check if we need approval
|
||||
cowboy_mode = _global_memory.get("config", {}).get("cowboy_mode", False)
|
||||
cowboy_mode = get_config_repository().get("cowboy_mode", False)
|
||||
|
||||
if cowboy_mode:
|
||||
console.print("")
|
||||
|
|
@ -74,7 +75,7 @@ def run_shell_command(
|
|||
"success": False,
|
||||
}
|
||||
elif response == "c":
|
||||
_global_memory["config"]["cowboy_mode"] = True
|
||||
get_config_repository().set("cowboy_mode", True)
|
||||
console.print("")
|
||||
console.print(" " + get_cowboy_message())
|
||||
console.print("")
|
||||
|
|
|
|||
|
|
@ -7,10 +7,25 @@ ensuring consistent test environments and proper isolation.
|
|||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_config_repository():
|
||||
"""Mock the config repository."""
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
|
||||
repo = MagicMock()
|
||||
# Default config values
|
||||
config_values = {"recursion_limit": 2}
|
||||
repo.get_all.return_value = config_values
|
||||
repo.get.side_effect = lambda key, default=None: config_values.get(key, default)
|
||||
get_config_repository.return_value = repo
|
||||
yield repo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db_environment(tmp_path, monkeypatch):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Unit tests for agent_utils.py."""
|
||||
|
||||
from typing import Any, Dict, Literal
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
import litellm
|
||||
import pytest
|
||||
|
|
@ -19,6 +19,7 @@ from ra_aid.agent_utils import (
|
|||
state_modifier,
|
||||
)
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -29,40 +30,70 @@ def mock_model():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_memory():
|
||||
"""Fixture providing a mock global memory store."""
|
||||
with patch("ra_aid.agent_utils._global_memory") as mock_mem:
|
||||
mock_mem.get.return_value = {}
|
||||
yield mock_mem
|
||||
def mock_config_repository():
|
||||
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Create a dictionary to simulate config
|
||||
config = {}
|
||||
|
||||
# Setup get method to return config values
|
||||
def get_config(key, default=None):
|
||||
return config.get(key, default)
|
||||
mock_repo.get.side_effect = get_config
|
||||
|
||||
# Setup get_all method to return all config values
|
||||
mock_repo.get_all.return_value = config
|
||||
|
||||
# Setup set method to update config values
|
||||
def set_config(key, value):
|
||||
config[key] = value
|
||||
mock_repo.set.side_effect = set_config
|
||||
|
||||
# Setup update method to update multiple config values
|
||||
def update_config(update_dict):
|
||||
config.update(update_dict)
|
||||
mock_repo.update.side_effect = update_config
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
def test_get_model_token_limit_anthropic(mock_memory):
|
||||
def test_get_model_token_limit_anthropic(mock_config_repository):
|
||||
"""Test get_model_token_limit with Anthropic model."""
|
||||
config = {"provider": "anthropic", "model": "claude2"}
|
||||
mock_config_repository.update(config)
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
|
||||
|
||||
|
||||
def test_get_model_token_limit_openai(mock_memory):
|
||||
def test_get_model_token_limit_openai(mock_config_repository):
|
||||
"""Test get_model_token_limit with OpenAI model."""
|
||||
config = {"provider": "openai", "model": "gpt-4"}
|
||||
mock_config_repository.update(config)
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
|
||||
|
||||
|
||||
def test_get_model_token_limit_unknown(mock_memory):
|
||||
def test_get_model_token_limit_unknown(mock_config_repository):
|
||||
"""Test get_model_token_limit with unknown provider/model."""
|
||||
config = {"provider": "unknown", "model": "unknown-model"}
|
||||
mock_config_repository.update(config)
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit is None
|
||||
|
||||
|
||||
def test_get_model_token_limit_missing_config(mock_memory):
|
||||
def test_get_model_token_limit_missing_config(mock_config_repository):
|
||||
"""Test get_model_token_limit with missing configuration."""
|
||||
config = {}
|
||||
mock_config_repository.update(config)
|
||||
|
||||
token_limit = get_model_token_limit(config, "default")
|
||||
assert token_limit is None
|
||||
|
|
@ -108,9 +139,9 @@ def test_get_model_token_limit_unexpected_error():
|
|||
assert token_limit is None
|
||||
|
||||
|
||||
def test_create_agent_anthropic(mock_model, mock_memory):
|
||||
def test_create_agent_anthropic(mock_model, mock_config_repository):
|
||||
"""Test create_agent with Anthropic Claude model."""
|
||||
mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"}
|
||||
mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
|
||||
|
||||
with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
|
||||
mock_react.return_value = "react_agent"
|
||||
|
|
@ -125,9 +156,9 @@ def test_create_agent_anthropic(mock_model, mock_memory):
|
|||
)
|
||||
|
||||
|
||||
def test_create_agent_openai(mock_model, mock_memory):
|
||||
def test_create_agent_openai(mock_model, mock_config_repository):
|
||||
"""Test create_agent with OpenAI model."""
|
||||
mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
|
||||
mock_config_repository.update({"provider": "openai", "model": "gpt-4"})
|
||||
|
||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||
mock_ciayn.return_value = "ciayn_agent"
|
||||
|
|
@ -142,9 +173,9 @@ def test_create_agent_openai(mock_model, mock_memory):
|
|||
)
|
||||
|
||||
|
||||
def test_create_agent_no_token_limit(mock_model, mock_memory):
|
||||
def test_create_agent_no_token_limit(mock_model, mock_config_repository):
|
||||
"""Test create_agent when no token limit is found."""
|
||||
mock_memory.get.return_value = {"provider": "unknown", "model": "unknown-model"}
|
||||
mock_config_repository.update({"provider": "unknown", "model": "unknown-model"})
|
||||
|
||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||
mock_ciayn.return_value = "ciayn_agent"
|
||||
|
|
@ -159,9 +190,9 @@ def test_create_agent_no_token_limit(mock_model, mock_memory):
|
|||
)
|
||||
|
||||
|
||||
def test_create_agent_missing_config(mock_model, mock_memory):
|
||||
def test_create_agent_missing_config(mock_model, mock_config_repository):
|
||||
"""Test create_agent with missing configuration."""
|
||||
mock_memory.get.return_value = {"provider": "openai"}
|
||||
mock_config_repository.update({"provider": "openai"})
|
||||
|
||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||
mock_ciayn.return_value = "ciayn_agent"
|
||||
|
|
@ -205,9 +236,9 @@ def test_state_modifier(mock_messages):
|
|||
assert result[-1] == mock_messages[-1]
|
||||
|
||||
|
||||
def test_create_agent_with_checkpointer(mock_model, mock_memory):
|
||||
def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
|
||||
"""Test create_agent with checkpointer argument."""
|
||||
mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
|
||||
mock_config_repository.update({"provider": "openai", "model": "gpt-4"})
|
||||
mock_checkpointer = Mock()
|
||||
|
||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||
|
|
@ -223,13 +254,13 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory):
|
|||
)
|
||||
|
||||
|
||||
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
|
||||
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
|
||||
"""Test create_agent sets up token limiting for Claude models when enabled."""
|
||||
mock_memory.get.return_value = {
|
||||
mock_config_repository.update({
|
||||
"provider": "anthropic",
|
||||
"model": "claude-2",
|
||||
"limit_tokens": True,
|
||||
}
|
||||
})
|
||||
|
||||
with (
|
||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||
|
|
@ -246,13 +277,13 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
|
|||
assert callable(args[1]["state_modifier"])
|
||||
|
||||
|
||||
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory):
|
||||
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
|
||||
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
|
||||
mock_memory.get.return_value = {
|
||||
mock_config_repository.update({
|
||||
"provider": "anthropic",
|
||||
"model": "claude-2",
|
||||
"limit_tokens": False,
|
||||
}
|
||||
})
|
||||
|
||||
with (
|
||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||
|
|
@ -267,7 +298,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
|
|||
mock_react.assert_called_once_with(mock_model, [], version="v2")
|
||||
|
||||
|
||||
def test_get_model_token_limit_research(mock_memory):
|
||||
def test_get_model_token_limit_research(mock_config_repository):
|
||||
"""Test get_model_token_limit with research provider and model."""
|
||||
config = {
|
||||
"provider": "openai",
|
||||
|
|
@ -275,13 +306,15 @@ def test_get_model_token_limit_research(mock_memory):
|
|||
"research_provider": "anthropic",
|
||||
"research_model": "claude-2",
|
||||
}
|
||||
mock_config_repository.update(config)
|
||||
|
||||
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
|
||||
mock_get_info.return_value = {"max_input_tokens": 150000}
|
||||
token_limit = get_model_token_limit(config, "research")
|
||||
assert token_limit == 150000
|
||||
|
||||
|
||||
def test_get_model_token_limit_planner(mock_memory):
|
||||
def test_get_model_token_limit_planner(mock_config_repository):
|
||||
"""Test get_model_token_limit with planner provider and model."""
|
||||
config = {
|
||||
"provider": "openai",
|
||||
|
|
@ -289,6 +322,8 @@ def test_get_model_token_limit_planner(mock_memory):
|
|||
"planner_provider": "deepseek",
|
||||
"planner_model": "dsm-1",
|
||||
}
|
||||
mock_config_repository.update(config)
|
||||
|
||||
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
|
||||
mock_get_info.return_value = {"max_input_tokens": 120000}
|
||||
token_limit = get_model_token_limit(config, "planner")
|
||||
|
|
|
|||
|
|
@ -1,41 +1,35 @@
|
|||
"""Tests for the is_informational_query and is_stage_requested functions."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ra_aid.__main__ import is_informational_query, is_stage_requested
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
|
||||
|
||||
|
||||
def test_is_informational_query():
|
||||
@pytest.fixture
|
||||
def config_repo():
|
||||
"""Fixture for config repository."""
|
||||
with ConfigRepositoryManager() as repo:
|
||||
yield repo
|
||||
|
||||
|
||||
def test_is_informational_query(config_repo):
|
||||
"""Test that is_informational_query only depends on research_only config setting."""
|
||||
# Clear global memory to ensure clean state
|
||||
_global_memory.clear()
|
||||
|
||||
# When research_only is True, should return True
|
||||
_global_memory["config"] = {"research_only": True}
|
||||
config_repo.set("research_only", True)
|
||||
assert is_informational_query() is True
|
||||
|
||||
# When research_only is False, should return False
|
||||
_global_memory["config"] = {"research_only": False}
|
||||
config_repo.set("research_only", False)
|
||||
assert is_informational_query() is False
|
||||
|
||||
# When config is empty, should return False (default)
|
||||
_global_memory.clear()
|
||||
_global_memory["config"] = {}
|
||||
assert is_informational_query() is False
|
||||
|
||||
# When global memory is empty, should return False (default)
|
||||
_global_memory.clear()
|
||||
config_repo.update({})
|
||||
assert is_informational_query() is False
|
||||
|
||||
|
||||
def test_is_stage_requested():
|
||||
"""Test that is_stage_requested always returns False now."""
|
||||
# Clear global memory to ensure clean state
|
||||
_global_memory.clear()
|
||||
|
||||
# Should always return False regardless of input
|
||||
assert is_stage_requested("implementation") is False
|
||||
assert is_stage_requested("anything_else") is False
|
||||
|
||||
# Even if we set implementation_requested in global memory
|
||||
_global_memory["implementation_requested"] = True
|
||||
assert is_stage_requested("implementation") is False
|
||||
|
|
@ -7,14 +7,50 @@ from ra_aid.__main__ import parse_arguments
|
|||
from ra_aid.config import DEFAULT_RECURSION_LIMIT
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
from ra_aid.database.repositories.work_log_repository import WorkLogEntry
|
||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_config_repository():
|
||||
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Create a dictionary to simulate config
|
||||
config = {}
|
||||
|
||||
# Setup get method to return config values
|
||||
def get_config(key, default=None):
|
||||
return config.get(key, default)
|
||||
mock_repo.get.side_effect = get_config
|
||||
|
||||
# Setup set method to update config values
|
||||
def set_config(key, value):
|
||||
config[key] = value
|
||||
mock_repo.set.side_effect = set_config
|
||||
|
||||
# Setup update method to update multiple config values
|
||||
def update_config(config_dict):
|
||||
config.update(config_dict)
|
||||
mock_repo.update.side_effect = update_config
|
||||
|
||||
# Setup get_all method to return the config dict
|
||||
def get_all_config():
|
||||
return config.copy()
|
||||
mock_repo.get_all.side_effect = get_all_config
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(monkeypatch):
|
||||
"""Mock all dependencies needed for main()."""
|
||||
# Initialize global memory with necessary keys to prevent KeyError
|
||||
# Initialize global memory
|
||||
_global_memory.clear()
|
||||
_global_memory["config"] = {}
|
||||
|
||||
# Mock dependencies that interact with external systems
|
||||
monkeypatch.setattr("ra_aid.__main__.check_dependencies", lambda: None)
|
||||
|
|
@ -26,10 +62,9 @@ def mock_dependencies(monkeypatch):
|
|||
|
||||
# Mock LLM initialization
|
||||
def mock_config_update(*args, **kwargs):
|
||||
config = _global_memory.get("config", {})
|
||||
config_repo = get_config_repository()
|
||||
if kwargs.get("temperature"):
|
||||
config["temperature"] = kwargs["temperature"]
|
||||
_global_memory["config"] = config
|
||||
config_repo.set("temperature", kwargs["temperature"])
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("ra_aid.__main__.initialize_llm", mock_config_update)
|
||||
|
|
@ -107,26 +142,52 @@ def mock_work_log_repository():
|
|||
yield mock_repo
|
||||
|
||||
|
||||
def test_recursion_limit_in_global_config(mock_dependencies):
|
||||
def test_recursion_limit_in_global_config(mock_dependencies, mock_config_repository):
|
||||
"""Test that recursion limit is correctly set in global config."""
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from ra_aid.__main__ import main
|
||||
|
||||
_global_memory.clear()
|
||||
# Clear the mock repository before each test
|
||||
mock_config_repository.update.reset_mock()
|
||||
|
||||
with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]):
|
||||
main()
|
||||
assert _global_memory["config"]["recursion_limit"] == DEFAULT_RECURSION_LIMIT
|
||||
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||
# Test default recursion limit
|
||||
with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]):
|
||||
main()
|
||||
# Check that the recursion_limit value was included in the update call
|
||||
mock_config_repository.update.assert_called()
|
||||
# Get the call arguments
|
||||
call_args = mock_config_repository.update.call_args_list
|
||||
# Find the call that includes recursion_limit
|
||||
recursion_limit_found = False
|
||||
for args, _ in call_args:
|
||||
config_dict = args[0]
|
||||
if "recursion_limit" in config_dict and config_dict["recursion_limit"] == DEFAULT_RECURSION_LIMIT:
|
||||
recursion_limit_found = True
|
||||
break
|
||||
assert recursion_limit_found, f"recursion_limit not found in update calls: {call_args}"
|
||||
|
||||
_global_memory.clear()
|
||||
# Reset mock to clear call history
|
||||
mock_config_repository.update.reset_mock()
|
||||
|
||||
with patch.object(
|
||||
sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"]
|
||||
):
|
||||
main()
|
||||
assert _global_memory["config"]["recursion_limit"] == 50
|
||||
# Test custom recursion limit
|
||||
with patch.object(sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"]):
|
||||
main()
|
||||
# Check that the recursion_limit value was included in the update call
|
||||
mock_config_repository.update.assert_called()
|
||||
# Get the call arguments
|
||||
call_args = mock_config_repository.update.call_args_list
|
||||
# Find the call that includes recursion_limit with value 50
|
||||
recursion_limit_found = False
|
||||
for args, _ in call_args:
|
||||
config_dict = args[0]
|
||||
if "recursion_limit" in config_dict and config_dict["recursion_limit"] == 50:
|
||||
recursion_limit_found = True
|
||||
break
|
||||
assert recursion_limit_found, f"recursion_limit=50 not found in update calls: {call_args}"
|
||||
|
||||
|
||||
def test_negative_recursion_limit():
|
||||
|
|
@ -141,70 +202,83 @@ def test_zero_recursion_limit():
|
|||
parse_arguments(["-m", "test message", "--recursion-limit", "0"])
|
||||
|
||||
|
||||
def test_config_settings(mock_dependencies):
|
||||
def test_config_settings(mock_dependencies, mock_config_repository):
|
||||
"""Test that various settings are correctly applied in global config."""
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from ra_aid.__main__ import main
|
||||
|
||||
_global_memory.clear()
|
||||
# Clear the mock repository before each test
|
||||
mock_config_repository.update.reset_mock()
|
||||
|
||||
with patch.object(
|
||||
sys,
|
||||
"argv",
|
||||
[
|
||||
"ra-aid",
|
||||
"-m",
|
||||
"test message",
|
||||
"--cowboy-mode",
|
||||
"--research-only",
|
||||
"--provider",
|
||||
"anthropic",
|
||||
"--model",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"--expert-provider",
|
||||
"openai",
|
||||
"--expert-model",
|
||||
"gpt-4",
|
||||
"--temperature",
|
||||
"0.7",
|
||||
"--disable-limit-tokens",
|
||||
],
|
||||
):
|
||||
main()
|
||||
config = _global_memory["config"]
|
||||
assert config["cowboy_mode"] is True
|
||||
assert config["research_only"] is True
|
||||
assert config["provider"] == "anthropic"
|
||||
assert config["model"] == "claude-3-7-sonnet-20250219"
|
||||
assert config["expert_provider"] == "openai"
|
||||
assert config["expert_model"] == "gpt-4"
|
||||
assert config["limit_tokens"] is False
|
||||
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||
with patch.object(
|
||||
sys,
|
||||
"argv",
|
||||
[
|
||||
"ra-aid",
|
||||
"-m",
|
||||
"test message",
|
||||
"--cowboy-mode",
|
||||
"--research-only",
|
||||
"--provider",
|
||||
"anthropic",
|
||||
"--model",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"--expert-provider",
|
||||
"openai",
|
||||
"--expert-model",
|
||||
"gpt-4",
|
||||
"--temperature",
|
||||
"0.7",
|
||||
"--disable-limit-tokens",
|
||||
],
|
||||
):
|
||||
main()
|
||||
# Verify config values are set via the update method
|
||||
mock_config_repository.update.assert_called()
|
||||
# Get the call arguments
|
||||
call_args = mock_config_repository.update.call_args_list
|
||||
|
||||
# Check for config values in the update calls
|
||||
for args, _ in call_args:
|
||||
config_dict = args[0]
|
||||
if "cowboy_mode" in config_dict:
|
||||
assert config_dict["cowboy_mode"] is True
|
||||
if "research_only" in config_dict:
|
||||
assert config_dict["research_only"] is True
|
||||
if "limit_tokens" in config_dict:
|
||||
assert config_dict["limit_tokens"] is False
|
||||
|
||||
# Check provider and model settings via set method
|
||||
mock_config_repository.set.assert_any_call("provider", "anthropic")
|
||||
mock_config_repository.set.assert_any_call("model", "claude-3-7-sonnet-20250219")
|
||||
mock_config_repository.set.assert_any_call("expert_provider", "openai")
|
||||
mock_config_repository.set.assert_any_call("expert_model", "gpt-4")
|
||||
|
||||
|
||||
def test_temperature_validation(mock_dependencies):
|
||||
def test_temperature_validation(mock_dependencies, mock_config_repository):
|
||||
"""Test that temperature argument is correctly passed to initialize_llm."""
|
||||
import sys
|
||||
from unittest.mock import patch, ANY
|
||||
|
||||
from ra_aid.__main__ import main
|
||||
|
||||
# Reset global memory for clean test
|
||||
_global_memory.clear()
|
||||
_global_memory["config"] = {}
|
||||
|
||||
# Test valid temperature (0.7)
|
||||
with patch("ra_aid.__main__.initialize_llm", return_value=None) as mock_init_llm:
|
||||
# Also patch any calls that would actually use the mocked initialize_llm function
|
||||
with patch("ra_aid.__main__.run_research_agent", return_value=None):
|
||||
with patch("ra_aid.__main__.run_planning_agent", return_value=None):
|
||||
with patch.object(
|
||||
sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"]
|
||||
):
|
||||
main()
|
||||
# Check if temperature was stored in config correctly
|
||||
assert _global_memory["config"]["temperature"] == 0.7
|
||||
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||
# Test valid temperature (0.7)
|
||||
with patch("ra_aid.__main__.initialize_llm", return_value=None) as mock_init_llm:
|
||||
# Also patch any calls that would actually use the mocked initialize_llm function
|
||||
with patch("ra_aid.__main__.run_research_agent", return_value=None):
|
||||
with patch("ra_aid.__main__.run_planning_agent", return_value=None):
|
||||
with patch.object(
|
||||
sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"]
|
||||
):
|
||||
main()
|
||||
# Verify that the temperature was set in the config repository
|
||||
mock_config_repository.set.assert_any_call("temperature", 0.7)
|
||||
|
||||
# Test invalid temperature (2.1)
|
||||
with pytest.raises(SystemExit):
|
||||
|
|
@ -230,61 +304,67 @@ def test_missing_message():
|
|||
assert args.message == "test"
|
||||
|
||||
|
||||
def test_research_model_provider_args(mock_dependencies):
|
||||
def test_research_model_provider_args(mock_dependencies, mock_config_repository):
|
||||
"""Test that research-specific model/provider args are correctly stored in config."""
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from ra_aid.__main__ import main
|
||||
|
||||
_global_memory.clear()
|
||||
# Reset mocks
|
||||
mock_config_repository.set.reset_mock()
|
||||
|
||||
with patch.object(
|
||||
sys,
|
||||
"argv",
|
||||
[
|
||||
"ra-aid",
|
||||
"-m",
|
||||
"test message",
|
||||
"--research-provider",
|
||||
"anthropic",
|
||||
"--research-model",
|
||||
"claude-3-haiku-20240307",
|
||||
"--planner-provider",
|
||||
"openai",
|
||||
"--planner-model",
|
||||
"gpt-4",
|
||||
],
|
||||
):
|
||||
main()
|
||||
config = _global_memory["config"]
|
||||
assert config["research_provider"] == "anthropic"
|
||||
assert config["research_model"] == "claude-3-haiku-20240307"
|
||||
assert config["planner_provider"] == "openai"
|
||||
assert config["planner_model"] == "gpt-4"
|
||||
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||
with patch.object(
|
||||
sys,
|
||||
"argv",
|
||||
[
|
||||
"ra-aid",
|
||||
"-m",
|
||||
"test message",
|
||||
"--research-provider",
|
||||
"anthropic",
|
||||
"--research-model",
|
||||
"claude-3-haiku-20240307",
|
||||
"--planner-provider",
|
||||
"openai",
|
||||
"--planner-model",
|
||||
"gpt-4",
|
||||
],
|
||||
):
|
||||
main()
|
||||
# Verify the mock repo's set method was called with the expected values
|
||||
mock_config_repository.set.assert_any_call("research_provider", "anthropic")
|
||||
mock_config_repository.set.assert_any_call("research_model", "claude-3-haiku-20240307")
|
||||
mock_config_repository.set.assert_any_call("planner_provider", "openai")
|
||||
mock_config_repository.set.assert_any_call("planner_model", "gpt-4")
|
||||
|
||||
|
||||
def test_planner_model_provider_args(mock_dependencies):
|
||||
def test_planner_model_provider_args(mock_dependencies, mock_config_repository):
|
||||
"""Test that planner provider/model args fall back to main config when not specified."""
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from ra_aid.__main__ import main
|
||||
|
||||
_global_memory.clear()
|
||||
# Reset mocks
|
||||
mock_config_repository.set.reset_mock()
|
||||
|
||||
with patch.object(
|
||||
sys,
|
||||
"argv",
|
||||
["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"],
|
||||
):
|
||||
main()
|
||||
config = _global_memory["config"]
|
||||
assert config["planner_provider"] == "openai"
|
||||
assert config["planner_model"] == "gpt-4"
|
||||
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||
with patch.object(
|
||||
sys,
|
||||
"argv",
|
||||
["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"],
|
||||
):
|
||||
main()
|
||||
# Verify the mock repo's set method was called with the expected values
|
||||
mock_config_repository.set.assert_any_call("planner_provider", "openai")
|
||||
mock_config_repository.set.assert_any_call("planner_model", "gpt-4")
|
||||
|
||||
|
||||
def test_use_aider_flag(mock_dependencies):
|
||||
def test_use_aider_flag(mock_dependencies, mock_config_repository):
|
||||
"""Test that use-aider flag is correctly stored in config."""
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
|
@ -292,44 +372,68 @@ def test_use_aider_flag(mock_dependencies):
|
|||
from ra_aid.__main__ import main
|
||||
from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools
|
||||
|
||||
_global_memory.clear()
|
||||
# Reset mocks
|
||||
mock_config_repository.update.reset_mock()
|
||||
|
||||
# Reset to default state
|
||||
set_modification_tools(False)
|
||||
|
||||
# Check default behavior (use_aider=False)
|
||||
with patch.object(
|
||||
sys,
|
||||
"argv",
|
||||
["ra-aid", "-m", "test message"],
|
||||
):
|
||||
main()
|
||||
config = _global_memory["config"]
|
||||
assert config.get("use_aider") is False
|
||||
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||
# Check default behavior (use_aider=False)
|
||||
with patch.object(
|
||||
sys,
|
||||
"argv",
|
||||
["ra-aid", "-m", "test message"],
|
||||
):
|
||||
main()
|
||||
# Verify use_aider is set to False in the update call
|
||||
mock_config_repository.update.assert_called()
|
||||
# Get the call arguments
|
||||
call_args = mock_config_repository.update.call_args_list
|
||||
# Find the call that includes use_aider
|
||||
use_aider_found = False
|
||||
for args, _ in call_args:
|
||||
config_dict = args[0]
|
||||
if "use_aider" in config_dict and config_dict["use_aider"] is False:
|
||||
use_aider_found = True
|
||||
break
|
||||
assert use_aider_found, f"use_aider=False not found in update calls: {call_args}"
|
||||
|
||||
# Check that file tools are enabled by default
|
||||
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
||||
assert "file_str_replace" in tool_names
|
||||
assert "put_complete_file_contents" in tool_names
|
||||
assert "run_programming_task" not in tool_names
|
||||
# Check that file tools are enabled by default
|
||||
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
||||
assert "file_str_replace" in tool_names
|
||||
assert "put_complete_file_contents" in tool_names
|
||||
assert "run_programming_task" not in tool_names
|
||||
|
||||
_global_memory.clear()
|
||||
# Reset mocks
|
||||
mock_config_repository.update.reset_mock()
|
||||
|
||||
# Check with --use-aider flag
|
||||
with patch.object(
|
||||
sys,
|
||||
"argv",
|
||||
["ra-aid", "-m", "test message", "--use-aider"],
|
||||
):
|
||||
main()
|
||||
config = _global_memory["config"]
|
||||
assert config.get("use_aider") is True
|
||||
# Check with --use-aider flag
|
||||
with patch.object(
|
||||
sys,
|
||||
"argv",
|
||||
["ra-aid", "-m", "test message", "--use-aider"],
|
||||
):
|
||||
main()
|
||||
# Verify use_aider is set to True in the update call
|
||||
mock_config_repository.update.assert_called()
|
||||
# Get the call arguments
|
||||
call_args = mock_config_repository.update.call_args_list
|
||||
# Find the call that includes use_aider
|
||||
use_aider_found = False
|
||||
for args, _ in call_args:
|
||||
config_dict = args[0]
|
||||
if "use_aider" in config_dict and config_dict["use_aider"] is True:
|
||||
use_aider_found = True
|
||||
break
|
||||
assert use_aider_found, f"use_aider=True not found in update calls: {call_args}"
|
||||
|
||||
# Check that run_programming_task is enabled
|
||||
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
||||
assert "file_str_replace" not in tool_names
|
||||
assert "put_complete_file_contents" not in tool_names
|
||||
assert "run_programming_task" in tool_names
|
||||
# Check that run_programming_task is enabled
|
||||
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
||||
assert "file_str_replace" not in tool_names
|
||||
assert "put_complete_file_contents" not in tool_names
|
||||
assert "run_programming_task" in tool_names
|
||||
|
||||
# Reset to default state for other tests
|
||||
set_modification_tools(False)
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
|
@ -7,6 +8,36 @@ from ra_aid.tools.programmer import (
|
|||
run_programming_task,
|
||||
)
|
||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_config_repository():
|
||||
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Create a dictionary to simulate config
|
||||
config = {
|
||||
"recursion_limit": 2,
|
||||
"provider": "anthropic",
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"temperature": 0.01,
|
||||
"aider_config": "/path/to/config.yml"
|
||||
}
|
||||
|
||||
# Setup get_all method to return the config dict
|
||||
mock_repo.get_all.return_value = config
|
||||
|
||||
# Setup get method to return config values
|
||||
def get_config(key, default=None):
|
||||
return config.get(key, default)
|
||||
mock_repo.get.side_effect = get_config
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_related_files_repository():
|
||||
|
|
@ -125,13 +156,9 @@ def test_parse_aider_flags(input_flags, expected, description):
|
|||
assert result == expected, f"Failed test case: {description}"
|
||||
|
||||
|
||||
def test_aider_config_flag(mocker, mock_related_files_repository):
|
||||
def test_aider_config_flag(mocker, mock_config_repository, mock_related_files_repository):
|
||||
"""Test that aider config flag is properly included in the command when specified."""
|
||||
# Mock config in global memory but not related files (using repository now)
|
||||
mock_memory = {
|
||||
"config": {"aider_config": "/path/to/config.yml"},
|
||||
}
|
||||
mocker.patch("ra_aid.tools.programmer._global_memory", mock_memory)
|
||||
# Config is mocked by mock_config_repository fixture
|
||||
|
||||
# Mock the run_interactive_command to capture the command that would be run
|
||||
mock_run = mocker.patch(
|
||||
|
|
@ -146,15 +173,14 @@ def test_aider_config_flag(mocker, mock_related_files_repository):
|
|||
assert args[config_index + 1] == "/path/to/config.yml"
|
||||
|
||||
|
||||
def test_path_normalization_and_deduplication(mocker, tmp_path, mock_related_files_repository):
|
||||
def test_path_normalization_and_deduplication(mocker, tmp_path, mock_config_repository, mock_related_files_repository):
|
||||
"""Test path normalization and deduplication in run_programming_task."""
|
||||
# Create a temporary test file
|
||||
test_file = tmp_path / "test.py"
|
||||
test_file.write_text("")
|
||||
new_file = tmp_path / "new.py"
|
||||
|
||||
# Mock dependencies - only need to mock config part of global memory now
|
||||
mocker.patch("ra_aid.tools.programmer._global_memory", {"config": {}})
|
||||
# Config is mocked by mock_config_repository fixture
|
||||
mocker.patch(
|
||||
"ra_aid.tools.programmer.get_aider_executable", return_value="/path/to/aider"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from ra_aid.tools.agent import (
|
|||
from ra_aid.tools.memory import _global_memory
|
||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository, WorkLogEntry
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -43,6 +44,34 @@ def mock_related_files_repository():
|
|||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_config_repository():
|
||||
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Create a dictionary to simulate config
|
||||
config = {
|
||||
"recursion_limit": 2,
|
||||
"provider": "anthropic",
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"temperature": 0.01
|
||||
}
|
||||
|
||||
# Setup get_all method to return the config dict
|
||||
mock_repo.get_all.return_value = config
|
||||
|
||||
# Setup get method to return config values
|
||||
def get_config(key, default=None):
|
||||
return config.get(key, default)
|
||||
mock_repo.get.side_effect = get_config
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_work_log_repository():
|
||||
"""Mock the WorkLogRepository to avoid database operations during tests"""
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from unittest.mock import patch
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
|
||||
from ra_aid.tools.shell import run_shell_command
|
||||
|
||||
|
||||
|
|
@ -25,9 +25,38 @@ def mock_run_interactive():
|
|||
yield mock
|
||||
|
||||
|
||||
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive):
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_config_repository():
|
||||
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Create a dictionary to simulate config
|
||||
config = {
|
||||
"cowboy_mode": False
|
||||
}
|
||||
|
||||
# Setup get method to return config values
|
||||
def get_config(key, default=None):
|
||||
return config.get(key, default)
|
||||
mock_repo.get.side_effect = get_config
|
||||
|
||||
# Setup set method to update config values
|
||||
def set_config(key, value):
|
||||
config[key] = value
|
||||
mock_repo.set.side_effect = set_config
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
|
||||
"""Test shell command execution in cowboy mode (no approval)"""
|
||||
_global_memory["config"] = {"cowboy_mode": True}
|
||||
# Set cowboy mode to True using the repository
|
||||
mock_config_repository.set("cowboy_mode", True)
|
||||
|
||||
result = run_shell_command.invoke({"command": "echo test"})
|
||||
|
||||
|
|
@ -37,9 +66,10 @@ def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interacti
|
|||
mock_prompt.ask.assert_not_called()
|
||||
|
||||
|
||||
def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive):
|
||||
def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
|
||||
"""Test that cowboy mode displays a properly formatted cowboy message with correct spacing"""
|
||||
_global_memory["config"] = {"cowboy_mode": True}
|
||||
# Set cowboy mode to True using the repository
|
||||
mock_config_repository.set("cowboy_mode", True)
|
||||
|
||||
with patch("ra_aid.tools.shell.get_cowboy_message") as mock_get_message:
|
||||
mock_get_message.return_value = "🤠 Test cowboy message!"
|
||||
|
|
@ -53,10 +83,11 @@ def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_intera
|
|||
|
||||
|
||||
def test_shell_command_interactive_approved(
|
||||
mock_console, mock_prompt, mock_run_interactive
|
||||
mock_console, mock_prompt, mock_run_interactive, mock_config_repository
|
||||
):
|
||||
"""Test shell command execution with interactive approval"""
|
||||
_global_memory["config"] = {"cowboy_mode": False}
|
||||
# Set cowboy mode to False using the repository
|
||||
mock_config_repository.set("cowboy_mode", False)
|
||||
mock_prompt.ask.return_value = "y"
|
||||
|
||||
result = run_shell_command.invoke({"command": "echo test"})
|
||||
|
|
@ -74,10 +105,11 @@ def test_shell_command_interactive_approved(
|
|||
|
||||
|
||||
def test_shell_command_interactive_rejected(
|
||||
mock_console, mock_prompt, mock_run_interactive
|
||||
mock_console, mock_prompt, mock_run_interactive, mock_config_repository
|
||||
):
|
||||
"""Test shell command rejection in interactive mode"""
|
||||
_global_memory["config"] = {"cowboy_mode": False}
|
||||
# Set cowboy mode to False using the repository
|
||||
mock_config_repository.set("cowboy_mode", False)
|
||||
mock_prompt.ask.return_value = "n"
|
||||
|
||||
result = run_shell_command.invoke({"command": "echo test"})
|
||||
|
|
@ -95,9 +127,10 @@ def test_shell_command_interactive_rejected(
|
|||
mock_run_interactive.assert_not_called()
|
||||
|
||||
|
||||
def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive):
|
||||
def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
|
||||
"""Test handling of shell command execution errors"""
|
||||
_global_memory["config"] = {"cowboy_mode": True}
|
||||
# Set cowboy mode to True using the repository
|
||||
mock_config_repository.set("cowboy_mode", True)
|
||||
mock_run_interactive.side_effect = Exception("Command failed")
|
||||
|
||||
result = run_shell_command.invoke({"command": "invalid command"})
|
||||
|
|
|
|||
Loading…
Reference in New Issue