# RA.Aid/ra_aid/database/repositories/config_repository.py
#
# 167 lines
# 5.0 KiB
# Python

"""Repository for managing configuration values."""
import contextvars
from typing import Any, Dict, Optional
# Context-local slot for the active ConfigRepository. It reads as None until
# a repository is stored in it (ConfigRepositoryManager.__enter__ does this).
config_repo_var: contextvars.ContextVar = contextvars.ContextVar("config_repo", default=None)
class ConfigRepository:
    """
    In-memory store for configuration values.

    The store is seeded with defaults taken from ``ra_aid.config``. Values can
    then be read one at a time, written one at a time, merged in bulk, or
    snapshotted as a whole. No database models are involved: everything lives
    in a plain dictionary held by the instance.
    """

    def __init__(self, initial_config: Optional[Dict[str, Any]] = None):
        """
        Build the repository and seed it with defaults.

        Args:
            initial_config: Optional dictionary whose entries override, or add
                to, the built-in defaults. An empty or None value leaves the
                defaults untouched.
        """
        # Imported inside the constructor rather than at module level
        # (presumably to avoid an import cycle with ra_aid.config; TODO confirm).
        from ra_aid.config import (
            DEFAULT_MAX_TEST_CMD_RETRIES,
            DEFAULT_MAX_TOOL_FAILURES,
            DEFAULT_RECURSION_LIMIT,
            DEFAULT_SHOW_COST,
            DEFAULT_TEST_CMD_TIMEOUT,
            FALLBACK_TOOL_MODEL_LIMIT,
            RETRY_FALLBACK_COUNT,
            VALID_PROVIDERS,
        )

        # Insertion order matters for callers that iterate get_all().
        defaults: Dict[str, Any] = dict(
            recursion_limit=DEFAULT_RECURSION_LIMIT,
            max_test_cmd_retries=DEFAULT_MAX_TEST_CMD_RETRIES,
            max_tool_failures=DEFAULT_MAX_TOOL_FAILURES,
            fallback_tool_model_limit=FALLBACK_TOOL_MODEL_LIMIT,
            retry_fallback_count=RETRY_FALLBACK_COUNT,
            test_cmd_timeout=DEFAULT_TEST_CMD_TIMEOUT,
            show_cost=DEFAULT_SHOW_COST,
            valid_providers=VALID_PROVIDERS,
        )
        self._config: Dict[str, Any] = defaults

        # Caller-supplied values take precedence over the defaults.
        if initial_config:
            self._config.update(initial_config)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Look up a single configuration value.

        Args:
            key: Name of the setting to read.
            default: Value handed back when ``key`` is not stored.

        Returns:
            The stored value for ``key``, or ``default`` when it is absent.
        """
        try:
            return self._config[key]
        except KeyError:
            return default

    def set(self, key: str, value: Any) -> None:
        """
        Store (or overwrite) one configuration value.

        Args:
            key: Name of the setting to write.
            value: New value for the setting.
        """
        self._config[key] = value

    def update(self, config_dict: Dict[str, Any]) -> None:
        """
        Merge several configuration values in one call.

        Args:
            config_dict: Entries to add; existing keys are overwritten.
        """
        self._config.update(config_dict)

    def get_all(self) -> Dict[str, Any]:
        """
        Snapshot every configuration value.

        Returns:
            A new (shallow-copied) dictionary. Adding or removing keys in it does
            not affect the repository, but mutable values are still shared.
        """
        return dict(self._config)
class ConfigRepositoryManager:
    """
    Context manager for ConfigRepository.

    Entering the context creates a fresh ConfigRepository, stores it in the
    ``config_repo_var`` context variable, and returns it. Exiting restores the
    context variable to whatever value it held before the matching
    ``__enter__``. As a result, nested managers unwind correctly: leaving an
    inner context makes the outer repository current again instead of
    clearing it.

    Example:
        with ConfigRepositoryManager() as repo:
            # Use the repository
            value = repo.get("key")
            repo.set("key", new_value)
    """

    def __init__(self, initial_config: Optional[Dict[str, Any]] = None):
        """
        Initialize the ConfigRepositoryManager.

        Args:
            initial_config: Optional dictionary of initial configuration values,
                passed to every ConfigRepository this manager creates.
        """
        self.initial_config = initial_config
        # Tokens from config_repo_var.set(), one per active __enter__, popped in
        # LIFO order so that even re-entering this same instance unwinds properly.
        self._tokens: list = []

    def __enter__(self) -> 'ConfigRepository':
        """
        Create a ConfigRepository, make it the current one, and return it.

        Returns:
            ConfigRepository: The initialized repository
        """
        repo = ConfigRepository(self.initial_config)
        # Keep the token so __exit__ can restore the previous value (possibly an
        # outer manager's repository) rather than blindly clearing it.
        self._tokens.append(config_repo_var.set(repo))
        return repo

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[Exception],
        exc_tb: Optional[object],
    ) -> bool:
        """
        Restore the previously current repository when exiting the context.

        Args:
            exc_type: The exception type if an exception was raised
            exc_val: The exception value if an exception was raised
            exc_tb: The traceback if an exception was raised

        Returns:
            False, so any exception raised inside the ``with`` block propagates.
        """
        token = self._tokens.pop() if self._tokens else None
        if token is None:
            # __exit__ without a matching __enter__: fall back to clearing.
            config_repo_var.set(None)
        else:
            try:
                config_repo_var.reset(token)
            except ValueError:
                # reset() rejects tokens created in a different Context (e.g. the
                # with-block spanned a context switch). In that case fall back to
                # the previous clear-to-None behavior instead of crashing on exit.
                config_repo_var.set(None)
        # Don't suppress exceptions
        return False
def get_config_repository() -> ConfigRepository:
    """
    Return the ConfigRepository stored in the current context.

    Returns:
        ConfigRepository: The value currently held by ``config_repo_var``.

    Raises:
        RuntimeError: If ``config_repo_var`` holds None, i.e. no repository has
            been set in the current context.
    """
    current = config_repo_var.get()
    if current is not None:
        return current
    raise RuntimeError(
        "ConfigRepository not initialized in current context. "
        "Make sure to use ConfigRepositoryManager."
    )