config repository
This commit is contained in:
parent
3e68dd3fa6
commit
5bd8c76a22
|
|
@ -58,6 +58,10 @@ from ra_aid.database.repositories.related_files_repository import (
|
||||||
from ra_aid.database.repositories.work_log_repository import (
|
from ra_aid.database.repositories.work_log_repository import (
|
||||||
WorkLogRepositoryManager
|
WorkLogRepositoryManager
|
||||||
)
|
)
|
||||||
|
from ra_aid.database.repositories.config_repository import (
|
||||||
|
ConfigRepositoryManager,
|
||||||
|
get_config_repository
|
||||||
|
)
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
from ra_aid.model_formatters import format_key_facts_dict
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
from ra_aid.console.output import cpm
|
from ra_aid.console.output import cpm
|
||||||
|
|
@ -77,7 +81,7 @@ from ra_aid.prompts.chat_prompts import CHAT_PROMPT
|
||||||
from ra_aid.prompts.web_research_prompts import WEB_RESEARCH_PROMPT_SECTION_CHAT
|
from ra_aid.prompts.web_research_prompts import WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||||
from ra_aid.tool_configs import get_chat_tools, set_modification_tools
|
from ra_aid.tool_configs import get_chat_tools, set_modification_tools
|
||||||
from ra_aid.tools.human import ask_human
|
from ra_aid.tools.human import ask_human
|
||||||
from ra_aid.tools.memory import _global_memory, get_memory_value
|
from ra_aid.tools.memory import get_memory_value
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -338,7 +342,7 @@ implementation_memory = MemorySaver()
|
||||||
|
|
||||||
def is_informational_query() -> bool:
|
def is_informational_query() -> bool:
|
||||||
"""Determine if the current query is informational based on config settings."""
|
"""Determine if the current query is informational based on config settings."""
|
||||||
return _global_memory.get("config", {}).get("research_only", False)
|
return get_config_repository().get("research_only", False)
|
||||||
|
|
||||||
|
|
||||||
def is_stage_requested(stage: str) -> bool:
|
def is_stage_requested(stage: str) -> bool:
|
||||||
|
|
@ -404,13 +408,17 @@ def main():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Database migration error: {str(e)}")
|
logger.error(f"Database migration error: {str(e)}")
|
||||||
|
|
||||||
|
# Initialize empty config dictionary to be populated later
|
||||||
|
config = {}
|
||||||
|
|
||||||
# Initialize repositories with database connection
|
# Initialize repositories with database connection
|
||||||
with KeyFactRepositoryManager(db) as key_fact_repo, \
|
with KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||||
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||||
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||||
RelatedFilesRepositoryManager() as related_files_repo, \
|
RelatedFilesRepositoryManager() as related_files_repo, \
|
||||||
WorkLogRepositoryManager() as work_log_repo:
|
WorkLogRepositoryManager() as work_log_repo, \
|
||||||
|
ConfigRepositoryManager(config) as config_repo:
|
||||||
# This initializes all repositories and makes them available via their respective get methods
|
# This initializes all repositories and makes them available via their respective get methods
|
||||||
logger.debug("Initialized KeyFactRepository")
|
logger.debug("Initialized KeyFactRepository")
|
||||||
logger.debug("Initialized KeySnippetRepository")
|
logger.debug("Initialized KeySnippetRepository")
|
||||||
|
|
@ -418,6 +426,7 @@ def main():
|
||||||
logger.debug("Initialized ResearchNoteRepository")
|
logger.debug("Initialized ResearchNoteRepository")
|
||||||
logger.debug("Initialized RelatedFilesRepository")
|
logger.debug("Initialized RelatedFilesRepository")
|
||||||
logger.debug("Initialized WorkLogRepository")
|
logger.debug("Initialized WorkLogRepository")
|
||||||
|
logger.debug("Initialized ConfigRepository")
|
||||||
|
|
||||||
# Check dependencies before proceeding
|
# Check dependencies before proceeding
|
||||||
check_dependencies()
|
check_dependencies()
|
||||||
|
|
@ -520,13 +529,13 @@ def main():
|
||||||
"limit_tokens": args.disable_limit_tokens,
|
"limit_tokens": args.disable_limit_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store config in global memory
|
# Store config in repository
|
||||||
_global_memory["config"] = config
|
config_repo.update(config)
|
||||||
_global_memory["config"]["provider"] = args.provider
|
config_repo.set("provider", args.provider)
|
||||||
_global_memory["config"]["model"] = args.model
|
config_repo.set("model", args.model)
|
||||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
config_repo.set("expert_provider", args.expert_provider)
|
||||||
_global_memory["config"]["expert_model"] = args.expert_model
|
config_repo.set("expert_model", args.expert_model)
|
||||||
_global_memory["config"]["temperature"] = args.temperature
|
config_repo.set("temperature", args.temperature)
|
||||||
|
|
||||||
# Set modification tools based on use_aider flag
|
# Set modification tools based on use_aider flag
|
||||||
set_modification_tools(args.use_aider)
|
set_modification_tools(args.use_aider)
|
||||||
|
|
@ -594,33 +603,27 @@ def main():
|
||||||
"test_cmd_timeout": args.test_cmd_timeout,
|
"test_cmd_timeout": args.test_cmd_timeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store config in global memory for access by is_informational_query
|
# Store config in repository
|
||||||
_global_memory["config"] = config
|
config_repo.update(config)
|
||||||
|
|
||||||
# Store base provider/model configuration
|
# Store base provider/model configuration
|
||||||
_global_memory["config"]["provider"] = args.provider
|
config_repo.set("provider", args.provider)
|
||||||
_global_memory["config"]["model"] = args.model
|
config_repo.set("model", args.model)
|
||||||
|
|
||||||
# Store expert provider/model (no fallback)
|
# Store expert provider/model (no fallback)
|
||||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
config_repo.set("expert_provider", args.expert_provider)
|
||||||
_global_memory["config"]["expert_model"] = args.expert_model
|
config_repo.set("expert_model", args.expert_model)
|
||||||
|
|
||||||
# Store planner config with fallback to base values
|
# Store planner config with fallback to base values
|
||||||
_global_memory["config"]["planner_provider"] = (
|
config_repo.set("planner_provider", args.planner_provider or args.provider)
|
||||||
args.planner_provider or args.provider
|
config_repo.set("planner_model", args.planner_model or args.model)
|
||||||
)
|
|
||||||
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
|
||||||
|
|
||||||
# Store research config with fallback to base values
|
# Store research config with fallback to base values
|
||||||
_global_memory["config"]["research_provider"] = (
|
config_repo.set("research_provider", args.research_provider or args.provider)
|
||||||
args.research_provider or args.provider
|
config_repo.set("research_model", args.research_model or args.model)
|
||||||
)
|
|
||||||
_global_memory["config"]["research_model"] = (
|
|
||||||
args.research_model or args.model
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store temperature in global config
|
# Store temperature in config
|
||||||
_global_memory["config"]["temperature"] = args.temperature
|
config_repo.set("temperature", args.temperature)
|
||||||
|
|
||||||
# Set modification tools based on use_aider flag
|
# Set modification tools based on use_aider flag
|
||||||
set_modification_tools(args.use_aider)
|
set_modification_tools(args.use_aider)
|
||||||
|
|
|
||||||
|
|
@ -94,11 +94,11 @@ from ra_aid.model_formatters import format_key_facts_dict
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
||||||
from ra_aid.tools.memory import (
|
from ra_aid.tools.memory import (
|
||||||
_global_memory,
|
|
||||||
get_memory_value,
|
get_memory_value,
|
||||||
get_related_files,
|
get_related_files,
|
||||||
log_work_event,
|
log_work_event,
|
||||||
)
|
)
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
@ -302,7 +302,7 @@ def create_agent(
|
||||||
config['limit_tokens'] = False.
|
config['limit_tokens'] = False.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
config = _global_memory.get("config", {})
|
config = get_config_repository().get_all()
|
||||||
max_input_tokens = (
|
max_input_tokens = (
|
||||||
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
|
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
|
||||||
)
|
)
|
||||||
|
|
@ -319,7 +319,7 @@ def create_agent(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Default to REACT agent if provider/model detection fails
|
# Default to REACT agent if provider/model detection fails
|
||||||
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
|
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
|
||||||
config = _global_memory.get("config", {})
|
config = get_config_repository().get_all()
|
||||||
max_input_tokens = get_model_token_limit(config, agent_type)
|
max_input_tokens = get_model_token_limit(config, agent_type)
|
||||||
agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
|
agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
|
||||||
return create_react_agent(model, tools, **agent_kwargs)
|
return create_react_agent(model, tools, **agent_kwargs)
|
||||||
|
|
@ -443,7 +443,7 @@ def run_research_agent(
|
||||||
new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "",
|
new_project_hints=NEW_PROJECT_HINTS if project_info.is_new else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
config = _global_memory.get("config", {}) if not config else config
|
config = get_config_repository().get_all() if not config else config
|
||||||
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||||
run_config = {
|
run_config = {
|
||||||
"configurable": {"thread_id": thread_id},
|
"configurable": {"thread_id": thread_id},
|
||||||
|
|
@ -575,7 +575,7 @@ def run_web_research_agent(
|
||||||
related_files=related_files,
|
related_files=related_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
config = _global_memory.get("config", {}) if not config else config
|
config = get_config_repository().get_all() if not config else config
|
||||||
|
|
||||||
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||||
run_config = {
|
run_config = {
|
||||||
|
|
@ -709,7 +709,7 @@ def run_planning_agent(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
config = _global_memory.get("config", {}) if not config else config
|
config = get_config_repository().get_all() if not config else config
|
||||||
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||||
run_config = {
|
run_config = {
|
||||||
"configurable": {"thread_id": thread_id},
|
"configurable": {"thread_id": thread_id},
|
||||||
|
|
@ -824,7 +824,7 @@ def run_task_implementation_agent(
|
||||||
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
||||||
human_section=(
|
human_section=(
|
||||||
HUMAN_PROMPT_SECTION_IMPLEMENTATION
|
HUMAN_PROMPT_SECTION_IMPLEMENTATION
|
||||||
if _global_memory.get("config", {}).get("hil", False)
|
if get_config_repository().get("hil", False)
|
||||||
else ""
|
else ""
|
||||||
),
|
),
|
||||||
web_research_section=(
|
web_research_section=(
|
||||||
|
|
@ -834,7 +834,7 @@ def run_task_implementation_agent(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
config = _global_memory.get("config", {}) if not config else config
|
config = get_config_repository().get_all() if not config else config
|
||||||
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||||
run_config = {
|
run_config = {
|
||||||
"configurable": {"thread_id": thread_id},
|
"configurable": {"thread_id": thread_id},
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,10 @@ logger = logging.getLogger(__name__)
|
||||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
from ra_aid.llm import initialize_llm
|
from ra_aid.llm import initialize_llm
|
||||||
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
||||||
from ra_aid.tools.memory import log_work_event, _global_memory
|
from ra_aid.tools.memory import log_work_event
|
||||||
|
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
@ -149,7 +150,7 @@ def run_key_facts_gc_agent() -> None:
|
||||||
formatted_facts = "\n".join([f"Fact #{k}: {v}" for k, v in facts_dict.items()])
|
formatted_facts = "\n".join([f"Fact #{k}: {v}" for k, v in facts_dict.items()])
|
||||||
|
|
||||||
# Retrieve configuration
|
# Retrieve configuration
|
||||||
llm_config = _global_memory.get("config", {})
|
llm_config = get_config_repository().get_all()
|
||||||
|
|
||||||
# Initialize the LLM model
|
# Initialize the LLM model
|
||||||
model = initialize_llm(
|
model = initialize_llm(
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,10 @@ from rich.panel import Panel
|
||||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
from ra_aid.llm import initialize_llm
|
from ra_aid.llm import initialize_llm
|
||||||
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
||||||
from ra_aid.tools.memory import log_work_event, _global_memory
|
from ra_aid.tools.memory import log_work_event
|
||||||
|
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
@ -153,7 +154,7 @@ def run_key_snippets_gc_agent() -> None:
|
||||||
])
|
])
|
||||||
|
|
||||||
# Retrieve configuration
|
# Retrieve configuration
|
||||||
llm_config = _global_memory.get("config", {})
|
llm_config = get_config_repository().get_all()
|
||||||
|
|
||||||
# Initialize the LLM model
|
# Initialize the LLM model
|
||||||
model = initialize_llm(
|
model = initialize_llm(
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,10 @@ logger = logging.getLogger(__name__)
|
||||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
from ra_aid.llm import initialize_llm
|
from ra_aid.llm import initialize_llm
|
||||||
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
||||||
from ra_aid.tools.memory import log_work_event, _global_memory
|
from ra_aid.tools.memory import log_work_event
|
||||||
|
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
@ -154,7 +155,7 @@ def run_research_notes_gc_agent(threshold: int = 30) -> None:
|
||||||
formatted_notes = "\n".join([f"Note #{k}: {v}" for k, v in notes_dict.items()])
|
formatted_notes = "\n".join([f"Note #{k}: {v}" for k, v in notes_dict.items()])
|
||||||
|
|
||||||
# Retrieve configuration
|
# Retrieve configuration
|
||||||
llm_config = _global_memory.get("config", {})
|
llm_config = get_config_repository().get_all()
|
||||||
|
|
||||||
# Initialize the LLM model
|
# Initialize the LLM model
|
||||||
model = initialize_llm(
|
model = initialize_llm(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,165 @@
|
||||||
|
"""Repository for managing configuration values."""
|
||||||
|
|
||||||
|
import contextvars
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
# Create contextvar to hold the ConfigRepository instance
|
||||||
|
config_repo_var = contextvars.ContextVar("config_repo", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigRepository:
|
||||||
|
"""
|
||||||
|
Repository for managing configuration values in memory.
|
||||||
|
|
||||||
|
This class provides methods to get, set, update, and retrieve all configuration values.
|
||||||
|
It does not require database models and operates entirely in memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, initial_config: Optional[Dict[str, Any]] = None):
|
||||||
|
"""
|
||||||
|
Initialize the ConfigRepository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_config: Optional dictionary of initial configuration values
|
||||||
|
"""
|
||||||
|
self._config: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Initialize with default values from config.py
|
||||||
|
from ra_aid.config import (
|
||||||
|
DEFAULT_RECURSION_LIMIT,
|
||||||
|
DEFAULT_MAX_TEST_CMD_RETRIES,
|
||||||
|
DEFAULT_MAX_TOOL_FAILURES,
|
||||||
|
FALLBACK_TOOL_MODEL_LIMIT,
|
||||||
|
RETRY_FALLBACK_COUNT,
|
||||||
|
DEFAULT_TEST_CMD_TIMEOUT,
|
||||||
|
VALID_PROVIDERS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._config = {
|
||||||
|
"recursion_limit": DEFAULT_RECURSION_LIMIT,
|
||||||
|
"max_test_cmd_retries": DEFAULT_MAX_TEST_CMD_RETRIES,
|
||||||
|
"max_tool_failures": DEFAULT_MAX_TOOL_FAILURES,
|
||||||
|
"fallback_tool_model_limit": FALLBACK_TOOL_MODEL_LIMIT,
|
||||||
|
"retry_fallback_count": RETRY_FALLBACK_COUNT,
|
||||||
|
"test_cmd_timeout": DEFAULT_TEST_CMD_TIMEOUT,
|
||||||
|
"valid_providers": VALID_PROVIDERS,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update with any provided initial configuration
|
||||||
|
if initial_config:
|
||||||
|
self._config.update(initial_config)
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""
|
||||||
|
Get a configuration value by key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Configuration key to retrieve
|
||||||
|
default: Default value to return if key is not found
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The configuration value or default if not found
|
||||||
|
"""
|
||||||
|
return self._config.get(key, default)
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any) -> None:
|
||||||
|
"""
|
||||||
|
Set a configuration value by key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Configuration key to set
|
||||||
|
value: Value to set for the key
|
||||||
|
"""
|
||||||
|
self._config[key] = value
|
||||||
|
|
||||||
|
def update(self, config_dict: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Update multiple configuration values at once.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict: Dictionary of configuration key-value pairs to update
|
||||||
|
"""
|
||||||
|
self._config.update(config_dict)
|
||||||
|
|
||||||
|
def get_all(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get all configuration values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing all configuration values
|
||||||
|
"""
|
||||||
|
return self._config.copy()
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigRepositoryManager:
|
||||||
|
"""
|
||||||
|
Context manager for ConfigRepository.
|
||||||
|
|
||||||
|
This class provides a context manager interface for ConfigRepository,
|
||||||
|
using the contextvars approach for thread safety.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with ConfigRepositoryManager() as repo:
|
||||||
|
# Use the repository
|
||||||
|
value = repo.get("key")
|
||||||
|
repo.set("key", new_value)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, initial_config: Optional[Dict[str, Any]] = None):
|
||||||
|
"""
|
||||||
|
Initialize the ConfigRepositoryManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_config: Optional dictionary of initial configuration values
|
||||||
|
"""
|
||||||
|
self.initial_config = initial_config
|
||||||
|
|
||||||
|
def __enter__(self) -> 'ConfigRepository':
|
||||||
|
"""
|
||||||
|
Initialize the ConfigRepository and return it.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConfigRepository: The initialized repository
|
||||||
|
"""
|
||||||
|
repo = ConfigRepository(self.initial_config)
|
||||||
|
config_repo_var.set(repo)
|
||||||
|
return repo
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type],
|
||||||
|
exc_val: Optional[Exception],
|
||||||
|
exc_tb: Optional[object],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Reset the repository when exiting the context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_type: The exception type if an exception was raised
|
||||||
|
exc_val: The exception value if an exception was raised
|
||||||
|
exc_tb: The traceback if an exception was raised
|
||||||
|
"""
|
||||||
|
# Reset the contextvar to None
|
||||||
|
config_repo_var.set(None)
|
||||||
|
|
||||||
|
# Don't suppress exceptions
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_repository() -> ConfigRepository:
|
||||||
|
"""
|
||||||
|
Get the current ConfigRepository instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConfigRepository: The current repository instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If no repository is set in the current context
|
||||||
|
"""
|
||||||
|
repo = config_repo_var.get()
|
||||||
|
if repo is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"ConfigRepository not initialized in current context. "
|
||||||
|
"Make sure to use ConfigRepositoryManager."
|
||||||
|
)
|
||||||
|
return repo
|
||||||
|
|
@ -27,6 +27,7 @@ from ra_aid.tools.agent import (
|
||||||
request_web_research,
|
request_web_research,
|
||||||
)
|
)
|
||||||
from ra_aid.tools.memory import plan_implementation_completed
|
from ra_aid.tools.memory import plan_implementation_completed
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
|
||||||
|
|
||||||
def set_modification_tools(use_aider=False):
|
def set_modification_tools(use_aider=False):
|
||||||
|
|
@ -98,13 +99,11 @@ def get_all_tools() -> list[BaseTool]:
|
||||||
|
|
||||||
|
|
||||||
# Define constant tool groups
|
# Define constant tool groups
|
||||||
# Get config from global memory for use_aider value
|
# Get config from repository for use_aider value
|
||||||
_config = {}
|
_config = {}
|
||||||
try:
|
try:
|
||||||
from ra_aid.tools.memory import _global_memory
|
_config = get_config_repository().get_all()
|
||||||
|
except (ImportError, RuntimeError):
|
||||||
_config = _global_memory.get("config", {})
|
|
||||||
except ImportError:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
READ_ONLY_TOOLS = get_read_only_tools(use_aider=_config.get("use_aider", False))
|
READ_ONLY_TOOLS = get_read_only_tools(use_aider=_config.get("use_aider", False))
|
||||||
|
|
@ -139,10 +138,8 @@ def get_research_tools(
|
||||||
# Get config for use_aider value
|
# Get config for use_aider value
|
||||||
use_aider = False
|
use_aider = False
|
||||||
try:
|
try:
|
||||||
from ra_aid.tools.memory import _global_memory
|
use_aider = get_config_repository().get("use_aider", False)
|
||||||
|
except (ImportError, RuntimeError):
|
||||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Start with read-only tools
|
# Start with read-only tools
|
||||||
|
|
@ -180,10 +177,8 @@ def get_planning_tools(
|
||||||
# Get config for use_aider value
|
# Get config for use_aider value
|
||||||
use_aider = False
|
use_aider = False
|
||||||
try:
|
try:
|
||||||
from ra_aid.tools.memory import _global_memory
|
use_aider = get_config_repository().get("use_aider", False)
|
||||||
|
except (ImportError, RuntimeError):
|
||||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Start with read-only tools
|
# Start with read-only tools
|
||||||
|
|
@ -219,10 +214,8 @@ def get_implementation_tools(
|
||||||
# Get config for use_aider value
|
# Get config for use_aider value
|
||||||
use_aider = False
|
use_aider = False
|
||||||
try:
|
try:
|
||||||
from ra_aid.tools.memory import _global_memory
|
use_aider = get_config_repository().get("use_aider", False)
|
||||||
|
except (ImportError, RuntimeError):
|
||||||
use_aider = _global_memory.get("config", {}).get("use_aider", False)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Start with read-only tools
|
# Start with read-only tools
|
||||||
|
|
@ -285,4 +278,4 @@ def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = Fal
|
||||||
if web_research_enabled:
|
if web_research_enabled:
|
||||||
tools.append(request_web_research)
|
tools.append(request_web_research)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
@ -18,13 +18,13 @@ from ra_aid.console.formatting import print_error
|
||||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||||
from ra_aid.exceptions import AgentInterrupt
|
from ra_aid.exceptions import AgentInterrupt
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
from ra_aid.model_formatters import format_key_facts_dict
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
||||||
from ra_aid.tools.memory import _global_memory
|
|
||||||
|
|
||||||
from ..console import print_task_header
|
from ..console import print_task_header
|
||||||
from ..llm import initialize_llm
|
from ..llm import initialize_llm
|
||||||
|
|
@ -52,7 +52,7 @@ def request_research(query: str) -> ResearchResult:
|
||||||
query: The research question or project description
|
query: The research question or project description
|
||||||
"""
|
"""
|
||||||
# Initialize model from config
|
# Initialize model from config
|
||||||
config = _global_memory.get("config", {})
|
config = get_config_repository().get_all()
|
||||||
model = initialize_llm(
|
model = initialize_llm(
|
||||||
config.get("provider", "anthropic"),
|
config.get("provider", "anthropic"),
|
||||||
config.get("model", "claude-3-7-sonnet-20250219"),
|
config.get("model", "claude-3-7-sonnet-20250219"),
|
||||||
|
|
@ -165,7 +165,7 @@ def request_web_research(query: str) -> ResearchResult:
|
||||||
query: The research question or project description
|
query: The research question or project description
|
||||||
"""
|
"""
|
||||||
# Initialize model from config
|
# Initialize model from config
|
||||||
config = _global_memory.get("config", {})
|
config = get_config_repository().get_all()
|
||||||
model = initialize_llm(
|
model = initialize_llm(
|
||||||
config.get("provider", "anthropic"),
|
config.get("provider", "anthropic"),
|
||||||
config.get("model", "claude-3-7-sonnet-20250219"),
|
config.get("model", "claude-3-7-sonnet-20250219"),
|
||||||
|
|
@ -246,7 +246,7 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]:
|
||||||
query: The research question or project description
|
query: The research question or project description
|
||||||
"""
|
"""
|
||||||
# Initialize model from config
|
# Initialize model from config
|
||||||
config = _global_memory.get("config", {})
|
config = get_config_repository().get_all()
|
||||||
model = initialize_llm(
|
model = initialize_llm(
|
||||||
config.get("provider", "anthropic"),
|
config.get("provider", "anthropic"),
|
||||||
config.get("model", "claude-3-7-sonnet-20250219"),
|
config.get("model", "claude-3-7-sonnet-20250219"),
|
||||||
|
|
@ -335,7 +335,7 @@ def request_task_implementation(task_spec: str) -> str:
|
||||||
task_spec: REQUIRED The full task specification (markdown format, typically one part of the overall plan)
|
task_spec: REQUIRED The full task specification (markdown format, typically one part of the overall plan)
|
||||||
"""
|
"""
|
||||||
# Initialize model from config
|
# Initialize model from config
|
||||||
config = _global_memory.get("config", {})
|
config = get_config_repository().get_all()
|
||||||
model = initialize_llm(
|
model = initialize_llm(
|
||||||
config.get("provider", "anthropic"),
|
config.get("provider", "anthropic"),
|
||||||
config.get("model", "claude-3-5-sonnet-20241022"),
|
config.get("model", "claude-3-5-sonnet-20241022"),
|
||||||
|
|
@ -474,7 +474,7 @@ def request_implementation(task_spec: str) -> str:
|
||||||
task_spec: The task specification to plan implementation for
|
task_spec: The task specification to plan implementation for
|
||||||
"""
|
"""
|
||||||
# Initialize model from config
|
# Initialize model from config
|
||||||
config = _global_memory.get("config", {})
|
config = get_config_repository().get_all()
|
||||||
model = initialize_llm(
|
model = initialize_llm(
|
||||||
config.get("provider", "anthropic"),
|
config.get("provider", "anthropic"),
|
||||||
config.get("model", "claude-3-5-sonnet-20241022"),
|
config.get("model", "claude-3-5-sonnet-20241022"),
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,12 @@ from ..database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ..database.repositories.related_files_repository import get_related_files_repository
|
from ..database.repositories.related_files_repository import get_related_files_repository
|
||||||
from ..database.repositories.research_note_repository import get_research_note_repository
|
from ..database.repositories.research_note_repository import get_research_note_repository
|
||||||
|
from ..database.repositories.config_repository import get_config_repository
|
||||||
from ..llm import initialize_expert_llm
|
from ..llm import initialize_expert_llm
|
||||||
from ..model_formatters import format_key_facts_dict
|
from ..model_formatters import format_key_facts_dict
|
||||||
from ..model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ..model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
from ..model_formatters.research_notes_formatter import format_research_notes_dict
|
from ..model_formatters.research_notes_formatter import format_research_notes_dict
|
||||||
from .memory import _global_memory, get_memory_value
|
from .memory import get_memory_value
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
_model = None
|
_model = None
|
||||||
|
|
@ -27,9 +28,9 @@ def get_model():
|
||||||
global _model
|
global _model
|
||||||
try:
|
try:
|
||||||
if _model is None:
|
if _model is None:
|
||||||
config = _global_memory["config"]
|
config_repo = get_config_repository()
|
||||||
provider = config.get("expert_provider") or config.get("provider")
|
provider = config_repo.get("expert_provider") or config_repo.get("provider")
|
||||||
model = config.get("expert_model") or config.get("model")
|
model = config_repo.get("expert_model") or config_repo.get("model")
|
||||||
_model = initialize_expert_llm(provider, model)
|
_model = initialize_expert_llm(provider, model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_model = None
|
_model = None
|
||||||
|
|
|
||||||
|
|
@ -58,10 +58,10 @@ def ask_human(question: str) -> str:
|
||||||
# Record human response in database
|
# Record human response in database
|
||||||
try:
|
try:
|
||||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
|
||||||
# Determine the source based on context
|
# Determine the source based on context
|
||||||
config = _global_memory.get("config", {})
|
config = get_config_repository().get_all()
|
||||||
# If chat_mode is enabled, use 'chat', otherwise determine if hil mode is active
|
# If chat_mode is enabled, use 'chat', otherwise determine if hil mode is active
|
||||||
if config.get("chat_mode", False):
|
if config.get("chat_mode", False):
|
||||||
source = "chat"
|
source = "chat"
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,8 @@ from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.models_params import DEFAULT_BASE_LATENCY, models_params
|
from ra_aid.models_params import DEFAULT_BASE_LATENCY, models_params
|
||||||
from ra_aid.proc.interactive import run_interactive_command
|
from ra_aid.proc.interactive import run_interactive_command
|
||||||
from ra_aid.text.processing import truncate_output
|
from ra_aid.text.processing import truncate_output
|
||||||
from ra_aid.tools.memory import _global_memory, log_work_event
|
from ra_aid.tools.memory import log_work_event
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
@ -107,8 +108,9 @@ def run_programming_task(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add config file if specified
|
# Add config file if specified
|
||||||
if "config" in _global_memory and _global_memory["config"].get("aider_config"):
|
config = get_config_repository().get_all()
|
||||||
command.extend(["--config", _global_memory["config"]["aider_config"]])
|
if config.get("aider_config"):
|
||||||
|
command.extend(["--config", config["aider_config"]])
|
||||||
|
|
||||||
# if environment variable AIDER_FLAGS exists then parse
|
# if environment variable AIDER_FLAGS exists then parse
|
||||||
if "AIDER_FLAGS" in os.environ:
|
if "AIDER_FLAGS" in os.environ:
|
||||||
|
|
@ -147,8 +149,9 @@ def run_programming_task(
|
||||||
# Run the command interactively
|
# Run the command interactively
|
||||||
print()
|
print()
|
||||||
# Get provider/model specific latency coefficient
|
# Get provider/model specific latency coefficient
|
||||||
provider = _global_memory.get("config", {}).get("provider", "")
|
config = get_config_repository().get_all()
|
||||||
model = _global_memory.get("config", {}).get("model", "")
|
provider = config.get("provider", "")
|
||||||
|
model = config.get("model", "")
|
||||||
latency = (
|
latency = (
|
||||||
models_params.get(provider, {})
|
models_params.get(provider, {})
|
||||||
.get(model, {})
|
.get(model, {})
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from ra_aid.console.cowboy_messages import get_cowboy_message
|
||||||
from ra_aid.proc.interactive import run_interactive_command
|
from ra_aid.proc.interactive import run_interactive_command
|
||||||
from ra_aid.text.processing import truncate_output
|
from ra_aid.text.processing import truncate_output
|
||||||
from ra_aid.tools.memory import _global_memory, log_work_event
|
from ra_aid.tools.memory import _global_memory, log_work_event
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
@ -46,7 +47,7 @@ def run_shell_command(
|
||||||
4. Add flags e.g. git --no-pager in order to reduce interaction required by the human.
|
4. Add flags e.g. git --no-pager in order to reduce interaction required by the human.
|
||||||
"""
|
"""
|
||||||
# Check if we need approval
|
# Check if we need approval
|
||||||
cowboy_mode = _global_memory.get("config", {}).get("cowboy_mode", False)
|
cowboy_mode = get_config_repository().get("cowboy_mode", False)
|
||||||
|
|
||||||
if cowboy_mode:
|
if cowboy_mode:
|
||||||
console.print("")
|
console.print("")
|
||||||
|
|
@ -74,7 +75,7 @@ def run_shell_command(
|
||||||
"success": False,
|
"success": False,
|
||||||
}
|
}
|
||||||
elif response == "c":
|
elif response == "c":
|
||||||
_global_memory["config"]["cowboy_mode"] = True
|
get_config_repository().set("cowboy_mode", True)
|
||||||
console.print("")
|
console.print("")
|
||||||
console.print(" " + get_cowboy_message())
|
console.print(" " + get_cowboy_message())
|
||||||
console.print("")
|
console.print("")
|
||||||
|
|
@ -96,4 +97,4 @@ def run_shell_command(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print()
|
print()
|
||||||
console.print(Panel(str(e), title="❌ Error", border_style="red"))
|
console.print(Panel(str(e), title="❌ Error", border_style="red"))
|
||||||
return {"output": str(e), "return_code": 1, "success": False}
|
return {"output": str(e), "return_code": 1, "success": False}
|
||||||
|
|
@ -7,10 +7,25 @@ ensuring consistent test environments and proper isolation.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_config_repository():
|
||||||
|
"""Mock the config repository."""
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
|
||||||
|
repo = MagicMock()
|
||||||
|
# Default config values
|
||||||
|
config_values = {"recursion_limit": 2}
|
||||||
|
repo.get_all.return_value = config_values
|
||||||
|
repo.get.side_effect = lambda key, default=None: config_values.get(key, default)
|
||||||
|
get_config_repository.return_value = repo
|
||||||
|
yield repo
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def isolated_db_environment(tmp_path, monkeypatch):
|
def isolated_db_environment(tmp_path, monkeypatch):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""Unit tests for agent_utils.py."""
|
"""Unit tests for agent_utils.py."""
|
||||||
|
|
||||||
from typing import Any, Dict, Literal
|
from typing import Any, Dict, Literal
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -19,6 +19,7 @@ from ra_aid.agent_utils import (
|
||||||
state_modifier,
|
state_modifier,
|
||||||
)
|
)
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||||
|
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -29,40 +30,70 @@ def mock_model():
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_memory():
|
def mock_config_repository():
|
||||||
"""Fixture providing a mock global memory store."""
|
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||||
with patch("ra_aid.agent_utils._global_memory") as mock_mem:
|
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||||
mock_mem.get.return_value = {}
|
# Setup a mock repository
|
||||||
yield mock_mem
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Create a dictionary to simulate config
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
# Setup get method to return config values
|
||||||
|
def get_config(key, default=None):
|
||||||
|
return config.get(key, default)
|
||||||
|
mock_repo.get.side_effect = get_config
|
||||||
|
|
||||||
|
# Setup get_all method to return all config values
|
||||||
|
mock_repo.get_all.return_value = config
|
||||||
|
|
||||||
|
# Setup set method to update config values
|
||||||
|
def set_config(key, value):
|
||||||
|
config[key] = value
|
||||||
|
mock_repo.set.side_effect = set_config
|
||||||
|
|
||||||
|
# Setup update method to update multiple config values
|
||||||
|
def update_config(update_dict):
|
||||||
|
config.update(update_dict)
|
||||||
|
mock_repo.update.side_effect = update_config
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_token_limit_anthropic(mock_memory):
|
def test_get_model_token_limit_anthropic(mock_config_repository):
|
||||||
"""Test get_model_token_limit with Anthropic model."""
|
"""Test get_model_token_limit with Anthropic model."""
|
||||||
config = {"provider": "anthropic", "model": "claude2"}
|
config = {"provider": "anthropic", "model": "claude2"}
|
||||||
|
mock_config_repository.update(config)
|
||||||
|
|
||||||
token_limit = get_model_token_limit(config, "default")
|
token_limit = get_model_token_limit(config, "default")
|
||||||
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
|
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_token_limit_openai(mock_memory):
|
def test_get_model_token_limit_openai(mock_config_repository):
|
||||||
"""Test get_model_token_limit with OpenAI model."""
|
"""Test get_model_token_limit with OpenAI model."""
|
||||||
config = {"provider": "openai", "model": "gpt-4"}
|
config = {"provider": "openai", "model": "gpt-4"}
|
||||||
|
mock_config_repository.update(config)
|
||||||
|
|
||||||
token_limit = get_model_token_limit(config, "default")
|
token_limit = get_model_token_limit(config, "default")
|
||||||
assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
|
assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_token_limit_unknown(mock_memory):
|
def test_get_model_token_limit_unknown(mock_config_repository):
|
||||||
"""Test get_model_token_limit with unknown provider/model."""
|
"""Test get_model_token_limit with unknown provider/model."""
|
||||||
config = {"provider": "unknown", "model": "unknown-model"}
|
config = {"provider": "unknown", "model": "unknown-model"}
|
||||||
|
mock_config_repository.update(config)
|
||||||
|
|
||||||
token_limit = get_model_token_limit(config, "default")
|
token_limit = get_model_token_limit(config, "default")
|
||||||
assert token_limit is None
|
assert token_limit is None
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_token_limit_missing_config(mock_memory):
|
def test_get_model_token_limit_missing_config(mock_config_repository):
|
||||||
"""Test get_model_token_limit with missing configuration."""
|
"""Test get_model_token_limit with missing configuration."""
|
||||||
config = {}
|
config = {}
|
||||||
|
mock_config_repository.update(config)
|
||||||
|
|
||||||
token_limit = get_model_token_limit(config, "default")
|
token_limit = get_model_token_limit(config, "default")
|
||||||
assert token_limit is None
|
assert token_limit is None
|
||||||
|
|
@ -108,9 +139,9 @@ def test_get_model_token_limit_unexpected_error():
|
||||||
assert token_limit is None
|
assert token_limit is None
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_anthropic(mock_model, mock_memory):
|
def test_create_agent_anthropic(mock_model, mock_config_repository):
|
||||||
"""Test create_agent with Anthropic Claude model."""
|
"""Test create_agent with Anthropic Claude model."""
|
||||||
mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"}
|
mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
|
||||||
|
|
||||||
with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
|
with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
|
||||||
mock_react.return_value = "react_agent"
|
mock_react.return_value = "react_agent"
|
||||||
|
|
@ -125,9 +156,9 @@ def test_create_agent_anthropic(mock_model, mock_memory):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_openai(mock_model, mock_memory):
|
def test_create_agent_openai(mock_model, mock_config_repository):
|
||||||
"""Test create_agent with OpenAI model."""
|
"""Test create_agent with OpenAI model."""
|
||||||
mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
|
mock_config_repository.update({"provider": "openai", "model": "gpt-4"})
|
||||||
|
|
||||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||||
mock_ciayn.return_value = "ciayn_agent"
|
mock_ciayn.return_value = "ciayn_agent"
|
||||||
|
|
@ -142,9 +173,9 @@ def test_create_agent_openai(mock_model, mock_memory):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_no_token_limit(mock_model, mock_memory):
|
def test_create_agent_no_token_limit(mock_model, mock_config_repository):
|
||||||
"""Test create_agent when no token limit is found."""
|
"""Test create_agent when no token limit is found."""
|
||||||
mock_memory.get.return_value = {"provider": "unknown", "model": "unknown-model"}
|
mock_config_repository.update({"provider": "unknown", "model": "unknown-model"})
|
||||||
|
|
||||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||||
mock_ciayn.return_value = "ciayn_agent"
|
mock_ciayn.return_value = "ciayn_agent"
|
||||||
|
|
@ -159,9 +190,9 @@ def test_create_agent_no_token_limit(mock_model, mock_memory):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_missing_config(mock_model, mock_memory):
|
def test_create_agent_missing_config(mock_model, mock_config_repository):
|
||||||
"""Test create_agent with missing configuration."""
|
"""Test create_agent with missing configuration."""
|
||||||
mock_memory.get.return_value = {"provider": "openai"}
|
mock_config_repository.update({"provider": "openai"})
|
||||||
|
|
||||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||||
mock_ciayn.return_value = "ciayn_agent"
|
mock_ciayn.return_value = "ciayn_agent"
|
||||||
|
|
@ -205,9 +236,9 @@ def test_state_modifier(mock_messages):
|
||||||
assert result[-1] == mock_messages[-1]
|
assert result[-1] == mock_messages[-1]
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_with_checkpointer(mock_model, mock_memory):
|
def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
|
||||||
"""Test create_agent with checkpointer argument."""
|
"""Test create_agent with checkpointer argument."""
|
||||||
mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
|
mock_config_repository.update({"provider": "openai", "model": "gpt-4"})
|
||||||
mock_checkpointer = Mock()
|
mock_checkpointer = Mock()
|
||||||
|
|
||||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||||
|
|
@ -223,13 +254,13 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
|
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
|
||||||
"""Test create_agent sets up token limiting for Claude models when enabled."""
|
"""Test create_agent sets up token limiting for Claude models when enabled."""
|
||||||
mock_memory.get.return_value = {
|
mock_config_repository.update({
|
||||||
"provider": "anthropic",
|
"provider": "anthropic",
|
||||||
"model": "claude-2",
|
"model": "claude-2",
|
||||||
"limit_tokens": True,
|
"limit_tokens": True,
|
||||||
}
|
})
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||||
|
|
@ -246,13 +277,13 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
|
||||||
assert callable(args[1]["state_modifier"])
|
assert callable(args[1]["state_modifier"])
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory):
|
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
|
||||||
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
|
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
|
||||||
mock_memory.get.return_value = {
|
mock_config_repository.update({
|
||||||
"provider": "anthropic",
|
"provider": "anthropic",
|
||||||
"model": "claude-2",
|
"model": "claude-2",
|
||||||
"limit_tokens": False,
|
"limit_tokens": False,
|
||||||
}
|
})
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||||
|
|
@ -267,7 +298,7 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
|
||||||
mock_react.assert_called_once_with(mock_model, [], version="v2")
|
mock_react.assert_called_once_with(mock_model, [], version="v2")
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_token_limit_research(mock_memory):
|
def test_get_model_token_limit_research(mock_config_repository):
|
||||||
"""Test get_model_token_limit with research provider and model."""
|
"""Test get_model_token_limit with research provider and model."""
|
||||||
config = {
|
config = {
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
|
|
@ -275,13 +306,15 @@ def test_get_model_token_limit_research(mock_memory):
|
||||||
"research_provider": "anthropic",
|
"research_provider": "anthropic",
|
||||||
"research_model": "claude-2",
|
"research_model": "claude-2",
|
||||||
}
|
}
|
||||||
|
mock_config_repository.update(config)
|
||||||
|
|
||||||
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
|
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
|
||||||
mock_get_info.return_value = {"max_input_tokens": 150000}
|
mock_get_info.return_value = {"max_input_tokens": 150000}
|
||||||
token_limit = get_model_token_limit(config, "research")
|
token_limit = get_model_token_limit(config, "research")
|
||||||
assert token_limit == 150000
|
assert token_limit == 150000
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_token_limit_planner(mock_memory):
|
def test_get_model_token_limit_planner(mock_config_repository):
|
||||||
"""Test get_model_token_limit with planner provider and model."""
|
"""Test get_model_token_limit with planner provider and model."""
|
||||||
config = {
|
config = {
|
||||||
"provider": "openai",
|
"provider": "openai",
|
||||||
|
|
@ -289,6 +322,8 @@ def test_get_model_token_limit_planner(mock_memory):
|
||||||
"planner_provider": "deepseek",
|
"planner_provider": "deepseek",
|
||||||
"planner_model": "dsm-1",
|
"planner_model": "dsm-1",
|
||||||
}
|
}
|
||||||
|
mock_config_repository.update(config)
|
||||||
|
|
||||||
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
|
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
|
||||||
mock_get_info.return_value = {"max_input_tokens": 120000}
|
mock_get_info.return_value = {"max_input_tokens": 120000}
|
||||||
token_limit = get_model_token_limit(config, "planner")
|
token_limit = get_model_token_limit(config, "planner")
|
||||||
|
|
|
||||||
|
|
@ -1,41 +1,35 @@
|
||||||
"""Tests for the is_informational_query and is_stage_requested functions."""
|
"""Tests for the is_informational_query and is_stage_requested functions."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from ra_aid.__main__ import is_informational_query, is_stage_requested
|
from ra_aid.__main__ import is_informational_query, is_stage_requested
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
|
||||||
|
|
||||||
|
|
||||||
def test_is_informational_query():
|
@pytest.fixture
|
||||||
|
def config_repo():
|
||||||
|
"""Fixture for config repository."""
|
||||||
|
with ConfigRepositoryManager() as repo:
|
||||||
|
yield repo
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_informational_query(config_repo):
|
||||||
"""Test that is_informational_query only depends on research_only config setting."""
|
"""Test that is_informational_query only depends on research_only config setting."""
|
||||||
# Clear global memory to ensure clean state
|
|
||||||
_global_memory.clear()
|
|
||||||
|
|
||||||
# When research_only is True, should return True
|
# When research_only is True, should return True
|
||||||
_global_memory["config"] = {"research_only": True}
|
config_repo.set("research_only", True)
|
||||||
assert is_informational_query() is True
|
assert is_informational_query() is True
|
||||||
|
|
||||||
# When research_only is False, should return False
|
# When research_only is False, should return False
|
||||||
_global_memory["config"] = {"research_only": False}
|
config_repo.set("research_only", False)
|
||||||
assert is_informational_query() is False
|
assert is_informational_query() is False
|
||||||
|
|
||||||
# When config is empty, should return False (default)
|
# When config is empty, should return False (default)
|
||||||
_global_memory.clear()
|
config_repo.update({})
|
||||||
_global_memory["config"] = {}
|
|
||||||
assert is_informational_query() is False
|
|
||||||
|
|
||||||
# When global memory is empty, should return False (default)
|
|
||||||
_global_memory.clear()
|
|
||||||
assert is_informational_query() is False
|
assert is_informational_query() is False
|
||||||
|
|
||||||
|
|
||||||
def test_is_stage_requested():
|
def test_is_stage_requested():
|
||||||
"""Test that is_stage_requested always returns False now."""
|
"""Test that is_stage_requested always returns False now."""
|
||||||
# Clear global memory to ensure clean state
|
|
||||||
_global_memory.clear()
|
|
||||||
|
|
||||||
# Should always return False regardless of input
|
# Should always return False regardless of input
|
||||||
assert is_stage_requested("implementation") is False
|
assert is_stage_requested("implementation") is False
|
||||||
assert is_stage_requested("anything_else") is False
|
assert is_stage_requested("anything_else") is False
|
||||||
|
|
||||||
# Even if we set implementation_requested in global memory
|
|
||||||
_global_memory["implementation_requested"] = True
|
|
||||||
assert is_stage_requested("implementation") is False
|
|
||||||
|
|
@ -7,14 +7,50 @@ from ra_aid.__main__ import parse_arguments
|
||||||
from ra_aid.config import DEFAULT_RECURSION_LIMIT
|
from ra_aid.config import DEFAULT_RECURSION_LIMIT
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.tools.memory import _global_memory
|
||||||
from ra_aid.database.repositories.work_log_repository import WorkLogEntry
|
from ra_aid.database.repositories.work_log_repository import WorkLogEntry
|
||||||
|
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_config_repository():
|
||||||
|
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Create a dictionary to simulate config
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
# Setup get method to return config values
|
||||||
|
def get_config(key, default=None):
|
||||||
|
return config.get(key, default)
|
||||||
|
mock_repo.get.side_effect = get_config
|
||||||
|
|
||||||
|
# Setup set method to update config values
|
||||||
|
def set_config(key, value):
|
||||||
|
config[key] = value
|
||||||
|
mock_repo.set.side_effect = set_config
|
||||||
|
|
||||||
|
# Setup update method to update multiple config values
|
||||||
|
def update_config(config_dict):
|
||||||
|
config.update(config_dict)
|
||||||
|
mock_repo.update.side_effect = update_config
|
||||||
|
|
||||||
|
# Setup get_all method to return the config dict
|
||||||
|
def get_all_config():
|
||||||
|
return config.copy()
|
||||||
|
mock_repo.get_all.side_effect = get_all_config
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_dependencies(monkeypatch):
|
def mock_dependencies(monkeypatch):
|
||||||
"""Mock all dependencies needed for main()."""
|
"""Mock all dependencies needed for main()."""
|
||||||
# Initialize global memory with necessary keys to prevent KeyError
|
# Initialize global memory
|
||||||
_global_memory.clear()
|
_global_memory.clear()
|
||||||
_global_memory["config"] = {}
|
|
||||||
|
|
||||||
# Mock dependencies that interact with external systems
|
# Mock dependencies that interact with external systems
|
||||||
monkeypatch.setattr("ra_aid.__main__.check_dependencies", lambda: None)
|
monkeypatch.setattr("ra_aid.__main__.check_dependencies", lambda: None)
|
||||||
|
|
@ -26,10 +62,9 @@ def mock_dependencies(monkeypatch):
|
||||||
|
|
||||||
# Mock LLM initialization
|
# Mock LLM initialization
|
||||||
def mock_config_update(*args, **kwargs):
|
def mock_config_update(*args, **kwargs):
|
||||||
config = _global_memory.get("config", {})
|
config_repo = get_config_repository()
|
||||||
if kwargs.get("temperature"):
|
if kwargs.get("temperature"):
|
||||||
config["temperature"] = kwargs["temperature"]
|
config_repo.set("temperature", kwargs["temperature"])
|
||||||
_global_memory["config"] = config
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
monkeypatch.setattr("ra_aid.__main__.initialize_llm", mock_config_update)
|
monkeypatch.setattr("ra_aid.__main__.initialize_llm", mock_config_update)
|
||||||
|
|
@ -107,26 +142,52 @@ def mock_work_log_repository():
|
||||||
yield mock_repo
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
def test_recursion_limit_in_global_config(mock_dependencies):
|
def test_recursion_limit_in_global_config(mock_dependencies, mock_config_repository):
|
||||||
"""Test that recursion limit is correctly set in global config."""
|
"""Test that recursion limit is correctly set in global config."""
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from ra_aid.__main__ import main
|
from ra_aid.__main__ import main
|
||||||
|
|
||||||
_global_memory.clear()
|
# Clear the mock repository before each test
|
||||||
|
mock_config_repository.update.reset_mock()
|
||||||
|
|
||||||
|
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||||
|
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||||
|
# Test default recursion limit
|
||||||
|
with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]):
|
||||||
|
main()
|
||||||
|
# Check that the recursion_limit value was included in the update call
|
||||||
|
mock_config_repository.update.assert_called()
|
||||||
|
# Get the call arguments
|
||||||
|
call_args = mock_config_repository.update.call_args_list
|
||||||
|
# Find the call that includes recursion_limit
|
||||||
|
recursion_limit_found = False
|
||||||
|
for args, _ in call_args:
|
||||||
|
config_dict = args[0]
|
||||||
|
if "recursion_limit" in config_dict and config_dict["recursion_limit"] == DEFAULT_RECURSION_LIMIT:
|
||||||
|
recursion_limit_found = True
|
||||||
|
break
|
||||||
|
assert recursion_limit_found, f"recursion_limit not found in update calls: {call_args}"
|
||||||
|
|
||||||
with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]):
|
# Reset mock to clear call history
|
||||||
main()
|
mock_config_repository.update.reset_mock()
|
||||||
assert _global_memory["config"]["recursion_limit"] == DEFAULT_RECURSION_LIMIT
|
|
||||||
|
# Test custom recursion limit
|
||||||
_global_memory.clear()
|
with patch.object(sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"]):
|
||||||
|
main()
|
||||||
with patch.object(
|
# Check that the recursion_limit value was included in the update call
|
||||||
sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"]
|
mock_config_repository.update.assert_called()
|
||||||
):
|
# Get the call arguments
|
||||||
main()
|
call_args = mock_config_repository.update.call_args_list
|
||||||
assert _global_memory["config"]["recursion_limit"] == 50
|
# Find the call that includes recursion_limit with value 50
|
||||||
|
recursion_limit_found = False
|
||||||
|
for args, _ in call_args:
|
||||||
|
config_dict = args[0]
|
||||||
|
if "recursion_limit" in config_dict and config_dict["recursion_limit"] == 50:
|
||||||
|
recursion_limit_found = True
|
||||||
|
break
|
||||||
|
assert recursion_limit_found, f"recursion_limit=50 not found in update calls: {call_args}"
|
||||||
|
|
||||||
|
|
||||||
def test_negative_recursion_limit():
|
def test_negative_recursion_limit():
|
||||||
|
|
@ -141,70 +202,83 @@ def test_zero_recursion_limit():
|
||||||
parse_arguments(["-m", "test message", "--recursion-limit", "0"])
|
parse_arguments(["-m", "test message", "--recursion-limit", "0"])
|
||||||
|
|
||||||
|
|
||||||
def test_config_settings(mock_dependencies):
|
def test_config_settings(mock_dependencies, mock_config_repository):
|
||||||
"""Test that various settings are correctly applied in global config."""
|
"""Test that various settings are correctly applied in global config."""
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from ra_aid.__main__ import main
|
from ra_aid.__main__ import main
|
||||||
|
|
||||||
_global_memory.clear()
|
# Clear the mock repository before each test
|
||||||
|
mock_config_repository.update.reset_mock()
|
||||||
with patch.object(
|
|
||||||
sys,
|
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||||
"argv",
|
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||||
[
|
with patch.object(
|
||||||
"ra-aid",
|
sys,
|
||||||
"-m",
|
"argv",
|
||||||
"test message",
|
[
|
||||||
"--cowboy-mode",
|
"ra-aid",
|
||||||
"--research-only",
|
"-m",
|
||||||
"--provider",
|
"test message",
|
||||||
"anthropic",
|
"--cowboy-mode",
|
||||||
"--model",
|
"--research-only",
|
||||||
"claude-3-7-sonnet-20250219",
|
"--provider",
|
||||||
"--expert-provider",
|
"anthropic",
|
||||||
"openai",
|
"--model",
|
||||||
"--expert-model",
|
"claude-3-7-sonnet-20250219",
|
||||||
"gpt-4",
|
"--expert-provider",
|
||||||
"--temperature",
|
"openai",
|
||||||
"0.7",
|
"--expert-model",
|
||||||
"--disable-limit-tokens",
|
"gpt-4",
|
||||||
],
|
"--temperature",
|
||||||
):
|
"0.7",
|
||||||
main()
|
"--disable-limit-tokens",
|
||||||
config = _global_memory["config"]
|
],
|
||||||
assert config["cowboy_mode"] is True
|
):
|
||||||
assert config["research_only"] is True
|
main()
|
||||||
assert config["provider"] == "anthropic"
|
# Verify config values are set via the update method
|
||||||
assert config["model"] == "claude-3-7-sonnet-20250219"
|
mock_config_repository.update.assert_called()
|
||||||
assert config["expert_provider"] == "openai"
|
# Get the call arguments
|
||||||
assert config["expert_model"] == "gpt-4"
|
call_args = mock_config_repository.update.call_args_list
|
||||||
assert config["limit_tokens"] is False
|
|
||||||
|
# Check for config values in the update calls
|
||||||
|
for args, _ in call_args:
|
||||||
|
config_dict = args[0]
|
||||||
|
if "cowboy_mode" in config_dict:
|
||||||
|
assert config_dict["cowboy_mode"] is True
|
||||||
|
if "research_only" in config_dict:
|
||||||
|
assert config_dict["research_only"] is True
|
||||||
|
if "limit_tokens" in config_dict:
|
||||||
|
assert config_dict["limit_tokens"] is False
|
||||||
|
|
||||||
|
# Check provider and model settings via set method
|
||||||
|
mock_config_repository.set.assert_any_call("provider", "anthropic")
|
||||||
|
mock_config_repository.set.assert_any_call("model", "claude-3-7-sonnet-20250219")
|
||||||
|
mock_config_repository.set.assert_any_call("expert_provider", "openai")
|
||||||
|
mock_config_repository.set.assert_any_call("expert_model", "gpt-4")
|
||||||
|
|
||||||
|
|
||||||
def test_temperature_validation(mock_dependencies):
|
def test_temperature_validation(mock_dependencies, mock_config_repository):
|
||||||
"""Test that temperature argument is correctly passed to initialize_llm."""
|
"""Test that temperature argument is correctly passed to initialize_llm."""
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import patch, ANY
|
from unittest.mock import patch, ANY
|
||||||
|
|
||||||
from ra_aid.__main__ import main
|
from ra_aid.__main__ import main
|
||||||
|
|
||||||
# Reset global memory for clean test
|
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||||
_global_memory.clear()
|
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||||
_global_memory["config"] = {}
|
# Test valid temperature (0.7)
|
||||||
|
with patch("ra_aid.__main__.initialize_llm", return_value=None) as mock_init_llm:
|
||||||
# Test valid temperature (0.7)
|
# Also patch any calls that would actually use the mocked initialize_llm function
|
||||||
with patch("ra_aid.__main__.initialize_llm", return_value=None) as mock_init_llm:
|
with patch("ra_aid.__main__.run_research_agent", return_value=None):
|
||||||
# Also patch any calls that would actually use the mocked initialize_llm function
|
with patch("ra_aid.__main__.run_planning_agent", return_value=None):
|
||||||
with patch("ra_aid.__main__.run_research_agent", return_value=None):
|
with patch.object(
|
||||||
with patch("ra_aid.__main__.run_planning_agent", return_value=None):
|
sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"]
|
||||||
with patch.object(
|
):
|
||||||
sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"]
|
main()
|
||||||
):
|
# Verify that the temperature was set in the config repository
|
||||||
main()
|
mock_config_repository.set.assert_any_call("temperature", 0.7)
|
||||||
# Check if temperature was stored in config correctly
|
|
||||||
assert _global_memory["config"]["temperature"] == 0.7
|
|
||||||
|
|
||||||
# Test invalid temperature (2.1)
|
# Test invalid temperature (2.1)
|
||||||
with pytest.raises(SystemExit):
|
with pytest.raises(SystemExit):
|
||||||
|
|
@ -230,61 +304,67 @@ def test_missing_message():
|
||||||
assert args.message == "test"
|
assert args.message == "test"
|
||||||
|
|
||||||
|
|
||||||
def test_research_model_provider_args(mock_dependencies):
|
def test_research_model_provider_args(mock_dependencies, mock_config_repository):
|
||||||
"""Test that research-specific model/provider args are correctly stored in config."""
|
"""Test that research-specific model/provider args are correctly stored in config."""
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from ra_aid.__main__ import main
|
from ra_aid.__main__ import main
|
||||||
|
|
||||||
_global_memory.clear()
|
# Reset mocks
|
||||||
|
mock_config_repository.set.reset_mock()
|
||||||
with patch.object(
|
|
||||||
sys,
|
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||||
"argv",
|
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||||
[
|
with patch.object(
|
||||||
"ra-aid",
|
sys,
|
||||||
"-m",
|
"argv",
|
||||||
"test message",
|
[
|
||||||
"--research-provider",
|
"ra-aid",
|
||||||
"anthropic",
|
"-m",
|
||||||
"--research-model",
|
"test message",
|
||||||
"claude-3-haiku-20240307",
|
"--research-provider",
|
||||||
"--planner-provider",
|
"anthropic",
|
||||||
"openai",
|
"--research-model",
|
||||||
"--planner-model",
|
"claude-3-haiku-20240307",
|
||||||
"gpt-4",
|
"--planner-provider",
|
||||||
],
|
"openai",
|
||||||
):
|
"--planner-model",
|
||||||
main()
|
"gpt-4",
|
||||||
config = _global_memory["config"]
|
],
|
||||||
assert config["research_provider"] == "anthropic"
|
):
|
||||||
assert config["research_model"] == "claude-3-haiku-20240307"
|
main()
|
||||||
assert config["planner_provider"] == "openai"
|
# Verify the mock repo's set method was called with the expected values
|
||||||
assert config["planner_model"] == "gpt-4"
|
mock_config_repository.set.assert_any_call("research_provider", "anthropic")
|
||||||
|
mock_config_repository.set.assert_any_call("research_model", "claude-3-haiku-20240307")
|
||||||
|
mock_config_repository.set.assert_any_call("planner_provider", "openai")
|
||||||
|
mock_config_repository.set.assert_any_call("planner_model", "gpt-4")
|
||||||
|
|
||||||
|
|
||||||
def test_planner_model_provider_args(mock_dependencies):
|
def test_planner_model_provider_args(mock_dependencies, mock_config_repository):
|
||||||
"""Test that planner provider/model args fall back to main config when not specified."""
|
"""Test that planner provider/model args fall back to main config when not specified."""
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from ra_aid.__main__ import main
|
from ra_aid.__main__ import main
|
||||||
|
|
||||||
_global_memory.clear()
|
# Reset mocks
|
||||||
|
mock_config_repository.set.reset_mock()
|
||||||
with patch.object(
|
|
||||||
sys,
|
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||||
"argv",
|
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||||
["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"],
|
with patch.object(
|
||||||
):
|
sys,
|
||||||
main()
|
"argv",
|
||||||
config = _global_memory["config"]
|
["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"],
|
||||||
assert config["planner_provider"] == "openai"
|
):
|
||||||
assert config["planner_model"] == "gpt-4"
|
main()
|
||||||
|
# Verify the mock repo's set method was called with the expected values
|
||||||
|
mock_config_repository.set.assert_any_call("planner_provider", "openai")
|
||||||
|
mock_config_repository.set.assert_any_call("planner_model", "gpt-4")
|
||||||
|
|
||||||
|
|
||||||
def test_use_aider_flag(mock_dependencies):
|
def test_use_aider_flag(mock_dependencies, mock_config_repository):
|
||||||
"""Test that use-aider flag is correctly stored in config."""
|
"""Test that use-aider flag is correctly stored in config."""
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
@ -292,44 +372,68 @@ def test_use_aider_flag(mock_dependencies):
|
||||||
from ra_aid.__main__ import main
|
from ra_aid.__main__ import main
|
||||||
from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools
|
from ra_aid.tool_configs import MODIFICATION_TOOLS, set_modification_tools
|
||||||
|
|
||||||
_global_memory.clear()
|
# Reset mocks
|
||||||
|
mock_config_repository.update.reset_mock()
|
||||||
|
|
||||||
# Reset to default state
|
# Reset to default state
|
||||||
set_modification_tools(False)
|
set_modification_tools(False)
|
||||||
|
|
||||||
# Check default behavior (use_aider=False)
|
# For testing, we need to patch ConfigRepositoryManager.__enter__ to return our mock
|
||||||
with patch.object(
|
with patch('ra_aid.database.repositories.config_repository.ConfigRepositoryManager.__enter__', return_value=mock_config_repository):
|
||||||
sys,
|
# Check default behavior (use_aider=False)
|
||||||
"argv",
|
with patch.object(
|
||||||
["ra-aid", "-m", "test message"],
|
sys,
|
||||||
):
|
"argv",
|
||||||
main()
|
["ra-aid", "-m", "test message"],
|
||||||
config = _global_memory["config"]
|
):
|
||||||
assert config.get("use_aider") is False
|
main()
|
||||||
|
# Verify use_aider is set to False in the update call
|
||||||
|
mock_config_repository.update.assert_called()
|
||||||
|
# Get the call arguments
|
||||||
|
call_args = mock_config_repository.update.call_args_list
|
||||||
|
# Find the call that includes use_aider
|
||||||
|
use_aider_found = False
|
||||||
|
for args, _ in call_args:
|
||||||
|
config_dict = args[0]
|
||||||
|
if "use_aider" in config_dict and config_dict["use_aider"] is False:
|
||||||
|
use_aider_found = True
|
||||||
|
break
|
||||||
|
assert use_aider_found, f"use_aider=False not found in update calls: {call_args}"
|
||||||
|
|
||||||
# Check that file tools are enabled by default
|
# Check that file tools are enabled by default
|
||||||
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
||||||
assert "file_str_replace" in tool_names
|
assert "file_str_replace" in tool_names
|
||||||
assert "put_complete_file_contents" in tool_names
|
assert "put_complete_file_contents" in tool_names
|
||||||
assert "run_programming_task" not in tool_names
|
assert "run_programming_task" not in tool_names
|
||||||
|
|
||||||
_global_memory.clear()
|
# Reset mocks
|
||||||
|
mock_config_repository.update.reset_mock()
|
||||||
|
|
||||||
# Check with --use-aider flag
|
# Check with --use-aider flag
|
||||||
with patch.object(
|
with patch.object(
|
||||||
sys,
|
sys,
|
||||||
"argv",
|
"argv",
|
||||||
["ra-aid", "-m", "test message", "--use-aider"],
|
["ra-aid", "-m", "test message", "--use-aider"],
|
||||||
):
|
):
|
||||||
main()
|
main()
|
||||||
config = _global_memory["config"]
|
# Verify use_aider is set to True in the update call
|
||||||
assert config.get("use_aider") is True
|
mock_config_repository.update.assert_called()
|
||||||
|
# Get the call arguments
|
||||||
|
call_args = mock_config_repository.update.call_args_list
|
||||||
|
# Find the call that includes use_aider
|
||||||
|
use_aider_found = False
|
||||||
|
for args, _ in call_args:
|
||||||
|
config_dict = args[0]
|
||||||
|
if "use_aider" in config_dict and config_dict["use_aider"] is True:
|
||||||
|
use_aider_found = True
|
||||||
|
break
|
||||||
|
assert use_aider_found, f"use_aider=True not found in update calls: {call_args}"
|
||||||
|
|
||||||
# Check that run_programming_task is enabled
|
# Check that run_programming_task is enabled
|
||||||
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
tool_names = [tool.name for tool in MODIFICATION_TOOLS]
|
||||||
assert "file_str_replace" not in tool_names
|
assert "file_str_replace" not in tool_names
|
||||||
assert "put_complete_file_contents" not in tool_names
|
assert "put_complete_file_contents" not in tool_names
|
||||||
assert "run_programming_task" in tool_names
|
assert "run_programming_task" in tool_names
|
||||||
|
|
||||||
# Reset to default state for other tests
|
# Reset to default state for other tests
|
||||||
set_modification_tools(False)
|
set_modification_tools(False)
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
|
@ -7,6 +8,36 @@ from ra_aid.tools.programmer import (
|
||||||
run_programming_task,
|
run_programming_task,
|
||||||
)
|
)
|
||||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_config_repository():
|
||||||
|
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Create a dictionary to simulate config
|
||||||
|
config = {
|
||||||
|
"recursion_limit": 2,
|
||||||
|
"provider": "anthropic",
|
||||||
|
"model": "claude-3-5-sonnet-20241022",
|
||||||
|
"temperature": 0.01,
|
||||||
|
"aider_config": "/path/to/config.yml"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Setup get_all method to return the config dict
|
||||||
|
mock_repo.get_all.return_value = config
|
||||||
|
|
||||||
|
# Setup get method to return config values
|
||||||
|
def get_config(key, default=None):
|
||||||
|
return config.get(key, default)
|
||||||
|
mock_repo.get.side_effect = get_config
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_related_files_repository():
|
def mock_related_files_repository():
|
||||||
|
|
@ -125,13 +156,9 @@ def test_parse_aider_flags(input_flags, expected, description):
|
||||||
assert result == expected, f"Failed test case: {description}"
|
assert result == expected, f"Failed test case: {description}"
|
||||||
|
|
||||||
|
|
||||||
def test_aider_config_flag(mocker, mock_related_files_repository):
|
def test_aider_config_flag(mocker, mock_config_repository, mock_related_files_repository):
|
||||||
"""Test that aider config flag is properly included in the command when specified."""
|
"""Test that aider config flag is properly included in the command when specified."""
|
||||||
# Mock config in global memory but not related files (using repository now)
|
# Config is mocked by mock_config_repository fixture
|
||||||
mock_memory = {
|
|
||||||
"config": {"aider_config": "/path/to/config.yml"},
|
|
||||||
}
|
|
||||||
mocker.patch("ra_aid.tools.programmer._global_memory", mock_memory)
|
|
||||||
|
|
||||||
# Mock the run_interactive_command to capture the command that would be run
|
# Mock the run_interactive_command to capture the command that would be run
|
||||||
mock_run = mocker.patch(
|
mock_run = mocker.patch(
|
||||||
|
|
@ -146,15 +173,14 @@ def test_aider_config_flag(mocker, mock_related_files_repository):
|
||||||
assert args[config_index + 1] == "/path/to/config.yml"
|
assert args[config_index + 1] == "/path/to/config.yml"
|
||||||
|
|
||||||
|
|
||||||
def test_path_normalization_and_deduplication(mocker, tmp_path, mock_related_files_repository):
|
def test_path_normalization_and_deduplication(mocker, tmp_path, mock_config_repository, mock_related_files_repository):
|
||||||
"""Test path normalization and deduplication in run_programming_task."""
|
"""Test path normalization and deduplication in run_programming_task."""
|
||||||
# Create a temporary test file
|
# Create a temporary test file
|
||||||
test_file = tmp_path / "test.py"
|
test_file = tmp_path / "test.py"
|
||||||
test_file.write_text("")
|
test_file.write_text("")
|
||||||
new_file = tmp_path / "new.py"
|
new_file = tmp_path / "new.py"
|
||||||
|
|
||||||
# Mock dependencies - only need to mock config part of global memory now
|
# Config is mocked by mock_config_repository fixture
|
||||||
mocker.patch("ra_aid.tools.programmer._global_memory", {"config": {}})
|
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"ra_aid.tools.programmer.get_aider_executable", return_value="/path/to/aider"
|
"ra_aid.tools.programmer.get_aider_executable", return_value="/path/to/aider"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from ra_aid.tools.agent import (
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.tools.memory import _global_memory
|
||||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository, WorkLogEntry
|
from ra_aid.database.repositories.work_log_repository import get_work_log_repository, WorkLogEntry
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -43,6 +44,34 @@ def mock_related_files_repository():
|
||||||
|
|
||||||
yield mock_repo
|
yield mock_repo
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_config_repository():
|
||||||
|
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Create a dictionary to simulate config
|
||||||
|
config = {
|
||||||
|
"recursion_limit": 2,
|
||||||
|
"provider": "anthropic",
|
||||||
|
"model": "claude-3-5-sonnet-20241022",
|
||||||
|
"temperature": 0.01
|
||||||
|
}
|
||||||
|
|
||||||
|
# Setup get_all method to return the config dict
|
||||||
|
mock_repo.get_all.return_value = config
|
||||||
|
|
||||||
|
# Setup get method to return config values
|
||||||
|
def get_config(key, default=None):
|
||||||
|
return config.get(key, default)
|
||||||
|
mock_repo.get.side_effect = get_config
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_work_log_repository():
|
def mock_work_log_repository():
|
||||||
"""Mock the WorkLogRepository to avoid database operations during tests"""
|
"""Mock the WorkLogRepository to avoid database operations during tests"""
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ra_aid.tools.memory import _global_memory
|
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager
|
||||||
from ra_aid.tools.shell import run_shell_command
|
from ra_aid.tools.shell import run_shell_command
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -25,9 +25,38 @@ def mock_run_interactive():
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive):
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_config_repository():
|
||||||
|
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Create a dictionary to simulate config
|
||||||
|
config = {
|
||||||
|
"cowboy_mode": False
|
||||||
|
}
|
||||||
|
|
||||||
|
# Setup get method to return config values
|
||||||
|
def get_config(key, default=None):
|
||||||
|
return config.get(key, default)
|
||||||
|
mock_repo.get.side_effect = get_config
|
||||||
|
|
||||||
|
# Setup set method to update config values
|
||||||
|
def set_config(key, value):
|
||||||
|
config[key] = value
|
||||||
|
mock_repo.set.side_effect = set_config
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
|
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
|
||||||
"""Test shell command execution in cowboy mode (no approval)"""
|
"""Test shell command execution in cowboy mode (no approval)"""
|
||||||
_global_memory["config"] = {"cowboy_mode": True}
|
# Set cowboy mode to True using the repository
|
||||||
|
mock_config_repository.set("cowboy_mode", True)
|
||||||
|
|
||||||
result = run_shell_command.invoke({"command": "echo test"})
|
result = run_shell_command.invoke({"command": "echo test"})
|
||||||
|
|
||||||
|
|
@ -37,9 +66,10 @@ def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interacti
|
||||||
mock_prompt.ask.assert_not_called()
|
mock_prompt.ask.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive):
|
def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
|
||||||
"""Test that cowboy mode displays a properly formatted cowboy message with correct spacing"""
|
"""Test that cowboy mode displays a properly formatted cowboy message with correct spacing"""
|
||||||
_global_memory["config"] = {"cowboy_mode": True}
|
# Set cowboy mode to True using the repository
|
||||||
|
mock_config_repository.set("cowboy_mode", True)
|
||||||
|
|
||||||
with patch("ra_aid.tools.shell.get_cowboy_message") as mock_get_message:
|
with patch("ra_aid.tools.shell.get_cowboy_message") as mock_get_message:
|
||||||
mock_get_message.return_value = "🤠 Test cowboy message!"
|
mock_get_message.return_value = "🤠 Test cowboy message!"
|
||||||
|
|
@ -53,10 +83,11 @@ def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_intera
|
||||||
|
|
||||||
|
|
||||||
def test_shell_command_interactive_approved(
|
def test_shell_command_interactive_approved(
|
||||||
mock_console, mock_prompt, mock_run_interactive
|
mock_console, mock_prompt, mock_run_interactive, mock_config_repository
|
||||||
):
|
):
|
||||||
"""Test shell command execution with interactive approval"""
|
"""Test shell command execution with interactive approval"""
|
||||||
_global_memory["config"] = {"cowboy_mode": False}
|
# Set cowboy mode to False using the repository
|
||||||
|
mock_config_repository.set("cowboy_mode", False)
|
||||||
mock_prompt.ask.return_value = "y"
|
mock_prompt.ask.return_value = "y"
|
||||||
|
|
||||||
result = run_shell_command.invoke({"command": "echo test"})
|
result = run_shell_command.invoke({"command": "echo test"})
|
||||||
|
|
@ -74,10 +105,11 @@ def test_shell_command_interactive_approved(
|
||||||
|
|
||||||
|
|
||||||
def test_shell_command_interactive_rejected(
|
def test_shell_command_interactive_rejected(
|
||||||
mock_console, mock_prompt, mock_run_interactive
|
mock_console, mock_prompt, mock_run_interactive, mock_config_repository
|
||||||
):
|
):
|
||||||
"""Test shell command rejection in interactive mode"""
|
"""Test shell command rejection in interactive mode"""
|
||||||
_global_memory["config"] = {"cowboy_mode": False}
|
# Set cowboy mode to False using the repository
|
||||||
|
mock_config_repository.set("cowboy_mode", False)
|
||||||
mock_prompt.ask.return_value = "n"
|
mock_prompt.ask.return_value = "n"
|
||||||
|
|
||||||
result = run_shell_command.invoke({"command": "echo test"})
|
result = run_shell_command.invoke({"command": "echo test"})
|
||||||
|
|
@ -95,13 +127,14 @@ def test_shell_command_interactive_rejected(
|
||||||
mock_run_interactive.assert_not_called()
|
mock_run_interactive.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive):
|
def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
|
||||||
"""Test handling of shell command execution errors"""
|
"""Test handling of shell command execution errors"""
|
||||||
_global_memory["config"] = {"cowboy_mode": True}
|
# Set cowboy mode to True using the repository
|
||||||
|
mock_config_repository.set("cowboy_mode", True)
|
||||||
mock_run_interactive.side_effect = Exception("Command failed")
|
mock_run_interactive.side_effect = Exception("Command failed")
|
||||||
|
|
||||||
result = run_shell_command.invoke({"command": "invalid command"})
|
result = run_shell_command.invoke({"command": "invalid command"})
|
||||||
|
|
||||||
assert result["success"] is False
|
assert result["success"] is False
|
||||||
assert result["return_code"] == 1
|
assert result["return_code"] == 1
|
||||||
assert "Command failed" in result["output"]
|
assert "Command failed" in result["output"]
|
||||||
Loading…
Reference in New Issue