diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index bd0b307..cf4fdb8 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -18,6 +18,7 @@ from ra_aid.tools import ( emit_research_subtask, request_complex_implementation, read_file_tool, write_file_tool, fuzzy_find_project_files, ripgrep_search, list_directory_tree, file_str_replace, swap_task_order ) +from ra_aid.env import validate_environment from ra_aid.tools.memory import _global_memory, get_related_files, one_shot_completed from ra_aid import print_agent_output, print_stage_header, print_task_header, print_error from ra_aid.prompts import ( diff --git a/ra_aid/config.py b/ra_aid/config.py index e6d4fd3..23a3531 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -1,93 +1 @@ -"""Configuration and environment validation utilities.""" - -import os -import sys -from dataclasses import dataclass -from typing import Tuple, List - -from ra_aid import print_error - -@dataclass -class ProviderConfig: - """Configuration for a provider.""" - key_name: str - base_required: bool = False - -PROVIDER_CONFIGS = { - "anthropic": ProviderConfig("ANTHROPIC_API_KEY", base_required=True), - "openai": ProviderConfig("OPENAI_API_KEY", base_required=True), - "openrouter": ProviderConfig("OPENROUTER_API_KEY", base_required=True), - "openai-compatible": ProviderConfig("OPENAI_API_KEY", base_required=True), -} - -def validate_environment(args) -> Tuple[bool, List[str]]: - """Validate required environment variables and dependencies. - - Args: - args: The parsed command line arguments containing: - - provider: The main LLM provider - - expert_provider: The expert LLM provider - - Returns: - Tuple containing: - - bool: Whether expert mode is enabled - - List[str]: List of missing expert configuration items - - Raises: - SystemExit: If required base environment variables are missing - """ - missing = [] - provider = args.provider - expert_provider = args.expert_provider - - # Check API keys based on provider configs - if provider in PROVIDER_CONFIGS: - config = PROVIDER_CONFIGS[provider] - if config.base_required and not os.environ.get(config.key_name): - missing.append(f'{config.key_name} environment variable is not set') - - # Special case for openai-compatible needing base URL - if provider == "openai-compatible" and not os.environ.get('OPENAI_API_BASE'): - missing.append('OPENAI_API_BASE environment variable is not set') - - expert_missing = [] - if expert_provider in PROVIDER_CONFIGS: - config = PROVIDER_CONFIGS[expert_provider] - expert_key = f'EXPERT_{config.key_name}' - expert_key_missing = not os.environ.get(expert_key) - - # Try fallback to base key if providers match - fallback_available = expert_provider == provider and os.environ.get(config.key_name) - if expert_key_missing and fallback_available: - os.environ[expert_key] = os.environ[config.key_name] - expert_key_missing = False - - if expert_key_missing: - expert_missing.append(f'{expert_key} environment variable is not set') - - # Special case for openai-compatible expert needing base URL - if expert_provider == "openai-compatible": - expert_base = 'EXPERT_OPENAI_API_BASE' - base_missing = not os.environ.get(expert_base) - base_fallback = expert_provider == provider and os.environ.get('OPENAI_API_BASE') - - if base_missing and base_fallback: - os.environ[expert_base] = os.environ['OPENAI_API_BASE'] - base_missing = False - - if base_missing: - expert_missing.append(f'{expert_base} environment variable is not set') - - # If main keys missing, we must exit immediately - if missing: - print_error("Missing required dependencies:") - for item in missing: - print_error(f"- {item}") - sys.exit(1) - - # If expert keys missing, we disable expert tools instead of exiting - expert_enabled = True - if expert_missing: - expert_enabled = False - - return expert_enabled, expert_missing \ No newline at end of file +"""Configuration utilities.""" diff --git a/ra_aid/env.py b/ra_aid/env.py index e69de29..585f31f 100644 --- a/ra_aid/env.py +++ b/ra_aid/env.py @@ -0,0 +1,93 @@ +"""Environment validation utilities.""" + +import os +import sys +from dataclasses import dataclass +from typing import Tuple, List + +from ra_aid import print_error + +@dataclass +class ProviderConfig: + """Configuration for a provider.""" + key_name: str + base_required: bool = False + +PROVIDER_CONFIGS = { + "anthropic": ProviderConfig("ANTHROPIC_API_KEY", base_required=True), + "openai": ProviderConfig("OPENAI_API_KEY", base_required=True), + "openrouter": ProviderConfig("OPENROUTER_API_KEY", base_required=True), + "openai-compatible": ProviderConfig("OPENAI_API_KEY", base_required=True), +} + +def validate_environment(args) -> Tuple[bool, List[str]]: + """Validate required environment variables and dependencies. + + Args: + args: The parsed command line arguments containing: + - provider: The main LLM provider + - expert_provider: The expert LLM provider + + Returns: + Tuple containing: + - bool: Whether expert mode is enabled + - List[str]: List of missing expert configuration items + + Raises: + SystemExit: If required base environment variables are missing + """ + missing = [] + provider = args.provider + expert_provider = args.expert_provider + + # Check API keys based on provider configs + if provider in PROVIDER_CONFIGS: + config = PROVIDER_CONFIGS[provider] + if config.base_required and not os.environ.get(config.key_name): + missing.append(f'{config.key_name} environment variable is not set') + + # Special case for openai-compatible needing base URL + if provider == "openai-compatible" and not os.environ.get('OPENAI_API_BASE'): + missing.append('OPENAI_API_BASE environment variable is not set') + + expert_missing = [] + if expert_provider in PROVIDER_CONFIGS: + config = PROVIDER_CONFIGS[expert_provider] + expert_key = f'EXPERT_{config.key_name}' + expert_key_missing = not os.environ.get(expert_key) + + # Try fallback to base key if providers match + fallback_available = expert_provider == provider and os.environ.get(config.key_name) + if expert_key_missing and fallback_available: + os.environ[expert_key] = os.environ[config.key_name] + expert_key_missing = False + + if expert_key_missing: + expert_missing.append(f'{expert_key} environment variable is not set') + + # Special case for openai-compatible expert needing base URL + if expert_provider == "openai-compatible": + expert_base = 'EXPERT_OPENAI_API_BASE' + base_missing = not os.environ.get(expert_base) + base_fallback = expert_provider == provider and os.environ.get('OPENAI_API_BASE') + + if base_missing and base_fallback: + os.environ[expert_base] = os.environ['OPENAI_API_BASE'] + base_missing = False + + if base_missing: + expert_missing.append(f'{expert_base} environment variable is not set') + + # If main keys missing, we must exit immediately + if missing: + print_error("Missing required dependencies:") + for item in missing: + print_error(f"- {item}") + sys.exit(1) + + # If expert keys missing, we disable expert tools instead of exiting + expert_enabled = True + if expert_missing: + expert_enabled = False + + return expert_enabled, expert_missing diff --git a/tests/ra_aid/test_env.py b/tests/ra_aid/test_env.py index e69de29..8b29f72 100644 --- a/tests/ra_aid/test_env.py +++ b/tests/ra_aid/test_env.py @@ -0,0 +1,90 @@ +import os +import pytest +from dataclasses import dataclass +from typing import Optional + +from ra_aid.env import validate_environment + +@dataclass +class MockArgs: + provider: str + expert_provider: str + model: Optional[str] = None + expert_model: Optional[str] = None + +@pytest.fixture +def clean_env(monkeypatch): + """Remove relevant environment variables before each test""" + env_vars = [ + 'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'OPENROUTER_API_KEY', + 'OPENAI_API_BASE', 'EXPERT_ANTHROPIC_API_KEY', 'EXPERT_OPENAI_API_KEY', + 'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE' + ] + for var in env_vars: + monkeypatch.delenv(var, raising=False) + +def test_anthropic_validation(clean_env, monkeypatch): + args = MockArgs(provider="anthropic", expert_provider="openai") + + # Should fail without API key + with pytest.raises(SystemExit): + validate_environment(args) + + # Should pass with API key + monkeypatch.setenv('ANTHROPIC_API_KEY', 'test-key') + expert_enabled, missing = validate_environment(args) + assert not expert_enabled + assert 'EXPERT_OPENAI_API_KEY environment variable is not set' in missing + +def test_openai_validation(clean_env, monkeypatch): + args = MockArgs(provider="openai", expert_provider="openai") + + # Should fail without API key + with pytest.raises(SystemExit): + validate_environment(args) + + # Should pass with API key and enable expert mode with fallback + monkeypatch.setenv('OPENAI_API_KEY', 'test-key') + expert_enabled, missing = validate_environment(args) + assert expert_enabled + assert not missing + assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key' + +def test_openai_compatible_validation(clean_env, monkeypatch): + args = MockArgs(provider="openai-compatible", expert_provider="openai-compatible") + + # Should fail without API key and base URL + with pytest.raises(SystemExit): + validate_environment(args) + + # Should fail with only API key + monkeypatch.setenv('OPENAI_API_KEY', 'test-key') + with pytest.raises(SystemExit): + validate_environment(args) + + # Should pass with both API key and base URL + monkeypatch.setenv('OPENAI_API_BASE', 'http://test') + expert_enabled, missing = validate_environment(args) + assert expert_enabled + assert not missing + assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key' + assert os.environ.get('EXPERT_OPENAI_API_BASE') == 'http://test' + +def test_expert_fallback(clean_env, monkeypatch): + args = MockArgs(provider="openai", expert_provider="openai") + + # Set only base API key + monkeypatch.setenv('OPENAI_API_KEY', 'test-key') + + # Should enable expert mode with fallback + expert_enabled, missing = validate_environment(args) + assert expert_enabled + assert not missing + assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key' + + # Should use explicit expert key if available + monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key') + expert_enabled, missing = validate_environment(args) + assert expert_enabled + assert not missing + assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'expert-key'