212 lines
7.7 KiB
Python
212 lines
7.7 KiB
Python
"""Environment validation utilities."""
|
|
|
|
import os
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from typing import Tuple, List, Any
|
|
|
|
from ra_aid import print_error
|
|
from ra_aid.provider_strategy import ProviderFactory, ValidationResult
|
|
|
|
@dataclass
class ValidationResult:
    """Outcome of a configuration validation check.

    Pairs a pass/fail flag with the human-readable messages describing
    what was found to be missing or wrong.
    """

    # True when the checked configuration passed validation.
    valid: bool
    # One message per problem found (e.g. an unset environment variable).
    missing_vars: List[str]
|
|
|
|
def validate_provider(provider: str) -> ValidationResult:
    """Check that the named provider is known and passes its own validation.

    Args:
        provider: Provider name to validate.

    Returns:
        An invalid result when no provider name is given or when
        ``ProviderFactory.create`` yields nothing for it; otherwise the
        result of the created strategy's ``validate()`` call.
    """
    if not provider:
        problems = ["No provider specified"]
        return ValidationResult(valid=False, missing_vars=problems)

    handler = ProviderFactory.create(provider)
    if handler:
        # Delegate the actual environment checks to the provider strategy.
        return handler.validate()

    return ValidationResult(valid=False, missing_vars=[f"Unknown provider: {provider}"])
|
|
|
|
def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None:
    """Fill unset expert environment variables from their base counterparts.

    The base/expert variable pairs considered are selected by
    ``expert_provider``. For each pair, the base value is written to
    ``os.environ`` under the expert name when the expert variable is unset
    or empty and the base variable holds a non-empty value.

    Args:
        base_provider: Base provider name (not consulted when selecting the
            variables to copy).
        expert_provider: Expert provider name; a name without an entry in
            the table results in nothing being copied.
    """
    # Both OpenAI flavours fall back to the same pair of variables.
    openai_pairs = {
        'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY',
        'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE',
    }
    fallback_table = {
        'openai': openai_pairs,
        'openai-compatible': openai_pairs,
        'anthropic': {
            'ANTHROPIC_API_KEY': 'EXPERT_ANTHROPIC_API_KEY',
            'ANTHROPIC_MODEL': 'EXPERT_ANTHROPIC_MODEL',
        },
        'openrouter': {
            'OPENROUTER_API_KEY': 'EXPERT_OPENROUTER_API_KEY',
        },
    }

    env = os.environ
    for source_name, target_name in fallback_table.get(expert_provider, {}).items():
        if env.get(target_name):
            # An explicit expert value always wins over the fallback.
            continue
        source_value = env.get(source_name)
        if source_value:
            env[target_name] = source_value
|
|
|
|
def validate_expert_provider(provider: str) -> ValidationResult:
    """Validate the expert provider, falling back to base variables.

    Args:
        provider: Expert provider name; a falsy value means no expert
            provider is configured and is reported as valid.

    Returns:
        A valid, empty result when ``provider`` is falsy; an invalid result
        when the factory does not produce a strategy for it; otherwise a
        result listing every ``EXPERT_``-prefixed variable that is still
        unset after the base-to-expert copy.
    """
    if not provider:
        return ValidationResult(valid=True, missing_vars=[])

    strategy = ProviderFactory.create(provider)
    if not strategy:
        return ValidationResult(valid=False, missing_vars=[f"Unknown expert provider: {provider}"])

    # Seed the EXPERT_* variables from the base ones before checking them.
    copy_base_to_expert_vars(provider, provider)

    # The first whitespace-separated token of each reported message is taken
    # as the variable name; the rest of the message is discarded.
    expert_keys = (
        f"EXPERT_{entry.split()[0]}" for entry in strategy.validate().missing_vars
    )
    unresolved = [
        f"{expert_key} environment variable is not set"
        for expert_key in expert_keys
        if not os.environ.get(expert_key)
    ]
    return ValidationResult(valid=not unresolved, missing_vars=unresolved)
|
|
|
|
def validate_web_research() -> ValidationResult:
    """Report whether the web research API key is present in the environment.

    Returns:
        A valid, empty result when ``TAVILY_API_KEY`` is set to a non-empty
        value; otherwise an invalid result naming the missing variable.
    """
    key = "TAVILY_API_KEY"
    if os.environ.get(key):
        return ValidationResult(valid=True, missing_vars=[])
    return ValidationResult(
        valid=False,
        missing_vars=[f"{key} environment variable is not set"],
    )
|
|
|
|
def print_missing_dependencies(missing_vars: List[str]) -> None:
    """Write each missing dependency to stderr, then terminate the process.

    Args:
        missing_vars: Messages to report, one per output line.

    Raises:
        SystemExit: Always, with exit status 1.
    """
    sys.stderr.writelines(f"Error: {item}\n" for item in missing_vars)
    sys.exit(1)
|
|
|
|
def validate_research_only_provider(args: Any) -> None:
    """Ensure a provider (and, where required, a model) is given for research-only mode.

    Args:
        args: Arguments object; its ``provider`` and ``model`` attributes
            are read when present.

    Raises:
        SystemExit: If no provider is specified, or if the provider is not
            ``'anthropic'`` and no model is specified.
    """
    provider = None
    if args and hasattr(args, 'provider'):
        provider = args.provider

    if not provider:
        sys.exit("No provider specified")

    if provider == 'anthropic':
        # Anthropic is the only provider allowed to omit the model here.
        return

    if not getattr(args, 'model', None):
        sys.exit("Model is required for non-Anthropic providers")
|
|
|
|
def validate_research_only(args: Any) -> tuple[bool, list[str], bool, list[str]]:
    """Validate environment variables for research-only mode.

    Only the web research key is checked; expert mode is always reported
    as disabled with nothing missing.

    Args:
        args: Arguments containing provider and expert provider settings
            (not read by this function).

    Returns:
        Tuple containing:
            - expert_enabled: Always ``False``
            - expert_missing: Always an empty list
            - web_research_enabled: Whether ``TAVILY_API_KEY`` is set to a
              non-empty value
            - web_research_missing: Empty when the key is set, otherwise a
              single message naming it
    """
    web_research_enabled = bool(os.environ.get('TAVILY_API_KEY'))
    if web_research_enabled:
        web_research_missing = []
    else:
        web_research_missing = ['TAVILY_API_KEY environment variable is not set']

    return False, [], web_research_enabled, web_research_missing
|
|
|
|
def validate_environment(args: Any) -> tuple[bool, list[str], bool, list[str]]:
    """Validate environment variables for providers and web research tools.

    When ``args.research_only`` is truthy the work is delegated to
    ``validate_research_only`` (preceded by
    ``validate_research_only_provider`` when ``args.model`` exists and is
    ``None``). Otherwise the main provider is validated, then the optional
    expert provider (with base variables copied in as a fallback), then the
    web research key.

    Args:
        args: Arguments containing provider and expert provider settings.
            ``expert_provider`` may be absent, in which case expert mode is
            treated as not requested.

    Returns:
        Tuple containing:
            - expert_enabled: Whether expert mode is enabled
            - expert_missing: List of missing expert dependencies
            - web_research_enabled: Whether web research is enabled
            - web_research_missing: List of missing web research dependencies

    Raises:
        SystemExit: If no provider is specified, the main or expert provider
            is unknown to the factory, or the main provider fails validation.
    """
    # For research-only mode, use separate validation
    if getattr(args, 'research_only', False):
        # Only validate provider and model when testing provider validation
        if hasattr(args, 'model') and args.model is None:
            validate_research_only_provider(args)
        return validate_research_only(args)

    # Expert mode stays disabled unless an expert provider validates below.
    expert_enabled = False
    expert_missing = []

    # Get provider from args
    provider = args.provider if args and hasattr(args, 'provider') else None
    if not provider:
        sys.exit("No provider specified")

    # Validate main provider; any failure here is fatal.
    strategy = ProviderFactory.create(provider, args)
    if not strategy:
        sys.exit(f"Unknown provider: {provider}")

    result = strategy.validate(args)
    if not result.valid:
        print_missing_dependencies(result.missing_vars)

    # Read the attribute defensively, like the other attributes above, so an
    # args object without ``expert_provider`` does not raise AttributeError.
    expert_provider = getattr(args, 'expert_provider', None)
    if expert_provider:
        # Copy base variables to expert if not set
        copy_base_to_expert_vars(provider, expert_provider)

        expert_strategy = ProviderFactory.create(expert_provider, args)
        if not expert_strategy:
            sys.exit(f"Unknown expert provider: {expert_provider}")

        expert_result = expert_strategy.validate(args)
        expert_missing = expert_result.missing_vars
        expert_enabled = len(expert_missing) == 0

        # If expert validation failed, try to copy base variables again and
        # revalidate. NOTE(review): the second copy looks like a no-op unless
        # validate() altered the environment - confirm before removing.
        if not expert_enabled:
            copy_base_to_expert_vars(provider, expert_provider)
            expert_result = expert_strategy.validate(args)
            expert_missing = expert_result.missing_vars
            expert_enabled = len(expert_missing) == 0

    # Validate web research dependencies (non-fatal; reported to the caller).
    web_result = validate_web_research()
    web_research_enabled = web_result.valid
    web_research_missing = web_result.missing_vars

    return expert_enabled, expert_missing, web_research_enabled, web_research_missing
|