341 lines
12 KiB
Python
341 lines
12 KiB
Python
"""Provider validation strategies."""
|
|
|
|
import os
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import Any, List, Optional
|
|
|
|
|
|
@dataclass
class ValidationResult:
    """Outcome of a provider validation run.

    Instances are built by the ``validate`` methods of the provider
    strategies, which set ``valid`` to True exactly when ``missing_vars``
    is empty.
    """

    # True when validation found no problems.
    valid: bool
    # Human-readable problem messages collected during validation
    # (e.g. "<NAME> environment variable is not set").
    missing_vars: List[str]
|
|
|
|
|
|
class ProviderStrategy(ABC):
    """Interface that every provider validation strategy implements.

    Concrete subclasses must override :meth:`validate`; the class cannot be
    instantiated directly because ``validate`` is abstract.
    """

    @abstractmethod
    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Validate the provider's environment configuration.

        Args:
            args: Optional object carrying command line options; subclasses
                decide which attributes (if any) they read from it.

        Returns:
            A ValidationResult describing whether validation passed and
            which problems were found. The base implementation has no body
            and returns None when called through ``super()``.
        """
|
|
|
|
|
|
class OpenAIStrategy(ProviderStrategy):
    """Validation strategy for the OpenAI provider."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Check that the OpenAI API key (and expert model) are available.

        When ``args.expert_provider`` equals ``"openai"`` the expert
        configuration is validated: ``EXPERT_OPENAI_API_KEY`` is read and,
        if empty or unset, populated from ``OPENAI_API_KEY`` (the value is
        also written back to ``os.environ``). If ``args.expert_model`` exists
        but is empty, a model is requested from
        ``ra_aid.llm.select_expert_model`` and stored on ``args``.
        Otherwise only ``OPENAI_API_KEY`` is checked.

        Args:
            args: Optional command line arguments object.

        Returns:
            ValidationResult listing every problem found.
        """
        problems = []
        expert_mode = bool(args) and getattr(args, "expert_provider", None) == "openai"

        # Base (non-expert) configuration: only the plain API key matters.
        if not expert_mode:
            if not os.environ.get("OPENAI_API_KEY"):
                problems.append("OPENAI_API_KEY environment variable is not set")
            return ValidationResult(valid=not problems, missing_vars=problems)

        api_key = os.environ.get("EXPERT_OPENAI_API_KEY")
        if not api_key:
            # Inherit the base key so later readers of the expert variable see it.
            fallback_key = os.environ.get("OPENAI_API_KEY")
            if fallback_key:
                os.environ["EXPERT_OPENAI_API_KEY"] = fallback_key
                api_key = fallback_key
        if not api_key:
            problems.append("EXPERT_OPENAI_API_KEY environment variable is not set")

        # Pick an expert model automatically when the attribute exists but is empty.
        if hasattr(args, "expert_model") and not args.expert_model:
            # Imported lazily, exactly where it is needed.
            from ra_aid.llm import select_expert_model

            chosen_model = select_expert_model("openai")
            if chosen_model:
                args.expert_model = chosen_model
            elif getattr(args, "research_only", False):
                # A missing model is only reported in research-only mode.
                problems.append("No suitable expert model available")

        return ValidationResult(valid=not problems, missing_vars=problems)
|
|
|
|
|
|
class OpenAICompatibleStrategy(ProviderStrategy):
    """Validation strategy for OpenAI-compatible providers."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Check API key, API base and (in research-only mode) the model.

        When ``args.expert_provider`` equals ``"openai-compatible"`` the
        ``EXPERT_OPENAI_API_KEY`` / ``EXPERT_OPENAI_API_BASE`` variables are
        validated, each falling back to (and being copied from) its
        non-expert counterpart. Otherwise ``OPENAI_API_KEY`` and
        ``OPENAI_API_BASE`` are validated. A model is required only when
        ``args.research_only`` is truthy.

        Args:
            args: Optional command line arguments object.

        Returns:
            ValidationResult listing every problem found.
        """
        problems = []
        model_required = getattr(args, "research_only", False)
        expert_mode = (
            bool(args)
            and getattr(args, "expert_provider", None) == "openai-compatible"
        )

        if not expert_mode:
            for var_name in ("OPENAI_API_KEY", "OPENAI_API_BASE"):
                if not os.environ.get(var_name):
                    problems.append(f"{var_name} environment variable is not set")

            if model_required:
                # Command line model first, then the environment.
                chosen = getattr(args, "model", None) or os.environ.get("OPENAI_MODEL")
                if not chosen:
                    problems.append(
                        "Model is required for OpenAI-compatible provider in research-only mode"
                    )
            return ValidationResult(valid=not problems, missing_vars=problems)

        # Expert configuration: each expert variable inherits from its base
        # variable when empty or unset, and the inherited value is persisted.
        inheritance = (
            ("EXPERT_OPENAI_API_KEY", "OPENAI_API_KEY"),
            ("EXPERT_OPENAI_API_BASE", "OPENAI_API_BASE"),
        )
        for expert_var, base_var in inheritance:
            value = os.environ.get(expert_var)
            if not value:
                inherited = os.environ.get(base_var)
                if inherited:
                    os.environ[expert_var] = inherited
                    value = inherited
            if not value:
                problems.append(f"{expert_var} environment variable is not set")

        if model_required:
            # Resolution order: args.expert_model, EXPERT_OPENAI_MODEL, OPENAI_MODEL.
            chosen = (
                getattr(args, "expert_model", None)
                or os.environ.get("EXPERT_OPENAI_MODEL")
                or os.environ.get("OPENAI_MODEL")
            )
            if not chosen:
                problems.append(
                    "Model is required for OpenAI-compatible provider in research-only mode"
                )

        return ValidationResult(valid=not problems, missing_vars=problems)
|
|
|
|
|
|
class AnthropicStrategy(ProviderStrategy):
    """Validation strategy for the Anthropic provider."""

    # Regular expressions tried with ``re.match`` against the model name.
    VALID_MODELS = ["claude-"]

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Check the Anthropic API key and that the model name is acceptable.

        When ``args.expert_provider`` equals ``"anthropic"`` the expert
        variables (``EXPERT_ANTHROPIC_API_KEY`` / ``EXPERT_ANTHROPIC_MODEL``)
        are used, each falling back to (and being copied from) the
        corresponding non-expert variable. A model given on ``args`` takes
        precedence over the environment. The resolved model must match one
        of the ``VALID_MODELS`` patterns.

        Args:
            args: Optional command line arguments object.

        Returns:
            ValidationResult listing every problem found.
        """
        problems = []
        expert_mode = bool(args) and getattr(args, "expert_provider", None) == "anthropic"

        # --- API key ---
        if expert_mode:
            api_key = os.environ.get("EXPERT_ANTHROPIC_API_KEY")
            if not api_key:
                # Inherit the base key and persist it under the expert name.
                fallback_key = os.environ.get("ANTHROPIC_API_KEY")
                if fallback_key:
                    os.environ["EXPERT_ANTHROPIC_API_KEY"] = fallback_key
                    api_key = fallback_key
            if not api_key:
                problems.append(
                    "EXPERT_ANTHROPIC_API_KEY environment variable is not set"
                )
        elif not os.environ.get("ANTHROPIC_API_KEY"):
            problems.append("ANTHROPIC_API_KEY environment variable is not set")

        # --- Model resolution: command line first, then environment ---
        if expert_mode:
            model_to_check = getattr(args, "expert_model", None)
            if not model_to_check:
                model_to_check = os.environ.get("EXPERT_ANTHROPIC_MODEL")
                if not model_to_check:
                    fallback_model = os.environ.get("ANTHROPIC_MODEL")
                    if fallback_model:
                        os.environ["EXPERT_ANTHROPIC_MODEL"] = fallback_model
                        model_to_check = fallback_model
        else:
            model_to_check = getattr(args, "model", None) or os.environ.get(
                "ANTHROPIC_MODEL"
            )

        # --- Model validation ---
        if not model_to_check:
            # Nothing to pattern-match; report the missing model only.
            problems.append("ANTHROPIC_MODEL environment variable is not set")
        elif not any(re.match(pattern, model_to_check) for pattern in self.VALID_MODELS):
            problems.append(
                f'Invalid Anthropic model: {model_to_check}. Must match one of these patterns: {", ".join(self.VALID_MODELS)}'
            )

        return ValidationResult(valid=not problems, missing_vars=problems)
|
|
|
|
|
|
class OpenRouterStrategy(ProviderStrategy):
    """Validation strategy for the OpenRouter provider."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Check that an OpenRouter API key is available.

        When ``args.expert_provider`` equals ``"openrouter"`` the variable
        ``EXPERT_OPENROUTER_API_KEY`` is checked, falling back to (and being
        copied from) ``OPENROUTER_API_KEY``. Otherwise only
        ``OPENROUTER_API_KEY`` is checked.

        Args:
            args: Optional command line arguments object.

        Returns:
            ValidationResult listing every problem found.
        """
        problems = []
        expert_mode = bool(args) and getattr(args, "expert_provider", None) == "openrouter"

        if expert_mode:
            api_key = os.environ.get("EXPERT_OPENROUTER_API_KEY")
            if not api_key:
                # Inherit the base key and persist it under the expert name.
                fallback_key = os.environ.get("OPENROUTER_API_KEY")
                if fallback_key:
                    os.environ["EXPERT_OPENROUTER_API_KEY"] = fallback_key
                    api_key = fallback_key
            if not api_key:
                problems.append(
                    "EXPERT_OPENROUTER_API_KEY environment variable is not set"
                )
        elif not os.environ.get("OPENROUTER_API_KEY"):
            problems.append("OPENROUTER_API_KEY environment variable is not set")

        return ValidationResult(valid=not problems, missing_vars=problems)
|
|
|
|
|
|
class GeminiStrategy(ProviderStrategy):
    """Validation strategy for the Gemini provider."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Check that a Gemini API key is available.

        When ``args.expert_provider`` equals ``"gemini"`` the variable
        ``EXPERT_GEMINI_API_KEY`` is checked, falling back to (and being
        copied from) ``GEMINI_API_KEY``. Otherwise only ``GEMINI_API_KEY``
        is checked.

        Args:
            args: Optional command line arguments object.

        Returns:
            ValidationResult listing every problem found.
        """
        problems = []

        # Guard clause: the plain (non-expert) configuration is the simple case.
        if not (args and getattr(args, "expert_provider", None) == "gemini"):
            if not os.environ.get("GEMINI_API_KEY"):
                problems.append("GEMINI_API_KEY environment variable is not set")
            return ValidationResult(valid=not problems, missing_vars=problems)

        expert_key = os.environ.get("EXPERT_GEMINI_API_KEY")
        if not expert_key:
            # Inherit the base key and persist it under the expert name.
            base_key = os.environ.get("GEMINI_API_KEY")
            if base_key:
                os.environ["EXPERT_GEMINI_API_KEY"] = base_key
                expert_key = base_key

        if not expert_key:
            problems.append("EXPERT_GEMINI_API_KEY environment variable is not set")

        return ValidationResult(valid=not problems, missing_vars=problems)
|
|
|
|
|
|
class DeepSeekStrategy(ProviderStrategy):
    """Validation strategy for the DeepSeek provider."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Check that a DeepSeek API key is available.

        When ``args.expert_provider`` equals ``"deepseek"`` the variable
        ``EXPERT_DEEPSEEK_API_KEY`` is checked, falling back to (and being
        copied from) ``DEEPSEEK_API_KEY``. Otherwise only
        ``DEEPSEEK_API_KEY`` is checked.

        Args:
            args: Optional command line arguments object.

        Returns:
            ValidationResult listing every problem found.
        """
        problems = []
        # Check if we're validating expert config
        validating_expert = (
            bool(args) and getattr(args, "expert_provider", None) == "deepseek"
        )

        if validating_expert:
            api_key = os.environ.get("EXPERT_DEEPSEEK_API_KEY")
            if not api_key:
                # Inherit the base key and persist it under the expert name.
                inherited_key = os.environ.get("DEEPSEEK_API_KEY")
                if inherited_key:
                    os.environ["EXPERT_DEEPSEEK_API_KEY"] = inherited_key
                    api_key = inherited_key
            key_problem = "EXPERT_DEEPSEEK_API_KEY environment variable is not set"
        else:
            api_key = os.environ.get("DEEPSEEK_API_KEY")
            key_problem = "DEEPSEEK_API_KEY environment variable is not set"

        if not api_key:
            problems.append(key_problem)

        return ValidationResult(valid=not problems, missing_vars=problems)
|
|
|
|
|
|
class OllamaStrategy(ProviderStrategy):
    """Validation strategy for the Ollama provider."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Check that ``OLLAMA_BASE_URL`` is set to a non-empty value.

        Args:
            args: Optional command line arguments object; not read by this
                strategy.

        Returns:
            ValidationResult that is valid when the variable is present and
            otherwise carries a single message naming it.
        """
        problems = (
            []
            if os.environ.get("OLLAMA_BASE_URL")
            else ["OLLAMA_BASE_URL environment variable is not set"]
        )
        return ValidationResult(valid=not problems, missing_vars=problems)
|
|
|
|
|
|
class ProviderFactory:
    """Factory for creating provider validation strategies."""

    @staticmethod
    def create(provider: str, args: Optional[Any] = None) -> Optional[ProviderStrategy]:
        """Create the validation strategy registered for ``provider``.

        Args:
            provider: Provider name (e.g. ``"openai"``, ``"anthropic"``).
            args: Optional command line arguments. Accepted for interface
                compatibility; not used when building the strategy.

        Returns:
            A new strategy instance for the provider, or None if the
            provider name is not recognized.
        """
        # Map names to classes rather than instances so that only the
        # requested strategy is instantiated, instead of building all of
        # them on every call.
        strategy_classes = {
            "openai": OpenAIStrategy,
            "openai-compatible": OpenAICompatibleStrategy,
            "anthropic": AnthropicStrategy,
            "openrouter": OpenRouterStrategy,
            "gemini": GeminiStrategy,
            "ollama": OllamaStrategy,
            "deepseek": DeepSeekStrategy,
        }
        strategy_cls = strategy_classes.get(provider)
        if strategy_cls is None:
            return None
        return strategy_cls()
|