FEAT fix command line args and env var dependencies on anthropic (#21)

* FIX provider cmdline args

* FIX Issue 18

* FIX ensure research-only requires a model
This commit is contained in:
Jose M Leon 2024-12-28 16:53:57 -05:00 committed by GitHub
parent 9944ec9ea4
commit 8b3f4d736c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 1027 additions and 185 deletions

View File

@ -5,13 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.10.3] - 1024-12-27 ## [0.10.3] - 2024-12-27
- Fix logging on interrupt. - Fix logging on interrupt.
- Fix web research prompt. - Fix web research prompt.
- Simplify planning stage by executing tasks directly. - Simplify planning stage by executing tasks directly.
- Make research notes available to more agents/tools. - Make research notes available to more agents/tools.
- Make read_file always output status panel.
## [0.10.2] - 2024-12-26 ## [0.10.2] - 2024-12-26

10
Makefile Normal file
View File

@ -0,0 +1,10 @@
.PHONY: test setup-dev setup-hooks
test:
python -m pytest
setup-dev:
pip install -e ".[dev]"
setup-hooks: setup-dev
pre-commit install

View File

@ -26,10 +26,14 @@ from ra_aid.logging_config import setup_logging, get_logger
from ra_aid.tool_configs import ( from ra_aid.tool_configs import (
get_chat_tools get_chat_tools
) )
import os
logger = get_logger(__name__) logger = get_logger(__name__)
def parse_arguments(): def parse_arguments(args=None):
VALID_PROVIDERS = ['anthropic', 'openai', 'openrouter', 'openai-compatible']
ANTHROPIC_DEFAULT_MODEL = 'claude-3-5-sonnet-20241022'
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='RA.Aid - AI Agent for executing programming and research tasks', description='RA.Aid - AI Agent for executing programming and research tasks',
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
@ -59,7 +63,7 @@ Examples:
'--provider', '--provider',
type=str, type=str,
default='anthropic', default='anthropic',
choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'], choices=VALID_PROVIDERS,
help='The LLM provider to use' help='The LLM provider to use'
) )
parser.add_argument( parser.add_argument(
@ -76,7 +80,7 @@ Examples:
'--expert-provider', '--expert-provider',
type=str, type=str,
default='openai', default='openai',
choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'], choices=VALID_PROVIDERS,
help='The LLM provider to use for expert knowledge queries (default: openai)' help='The LLM provider to use for expert knowledge queries (default: openai)'
) )
parser.add_argument( parser.add_argument(
@ -100,24 +104,31 @@ Examples:
help='Enable verbose logging output' help='Enable verbose logging output'
) )
args = parser.parse_args() if args is None:
args = sys.argv[1:]
parsed_args = parser.parse_args(args)
# Set hil=True when chat mode is enabled # Set hil=True when chat mode is enabled
if args.chat: if parsed_args.chat:
args.hil = True parsed_args.hil = True
# Set default model for Anthropic, require model for other providers # Validate provider
if args.provider == 'anthropic': if parsed_args.provider not in VALID_PROVIDERS:
if not args.model: parser.error(f"Invalid provider: {parsed_args.provider}")
args.model = 'claude-3-5-sonnet-20241022'
elif not args.model: # Handle model defaults and requirements
parser.error(f"--model is required when using provider '{args.provider}'") if parsed_args.provider == 'anthropic':
# Always use default model for Anthropic
parsed_args.model = ANTHROPIC_DEFAULT_MODEL
elif not parsed_args.model and not parsed_args.research_only:
# Require model for other providers unless in research mode
parser.error(f"--model is required when using provider '{parsed_args.provider}'")
# Validate expert model requirement # Validate expert model requirement
if args.expert_provider != 'openai' and not args.expert_model: if parsed_args.expert_provider != 'openai' and not parsed_args.expert_model and not parsed_args.research_only:
parser.error(f"--expert-model is required when using expert provider '{args.expert_provider}'") parser.error(f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'")
return args return parsed_args
# Create console instance # Create console instance
console = Console() console = Console()
@ -171,6 +182,10 @@ def main():
# Handle chat mode # Handle chat mode
if args.chat: if args.chat:
if args.research_only:
print_error("Chat mode cannot be used with --research-only")
sys.exit(1)
print_stage_header("Chat Mode") print_stage_header("Chat Mode")
# Get initial request from user # Get initial request from user

View File

@ -162,9 +162,24 @@ def run_research_agent(
if console_message: if console_message:
console.print(Panel(Markdown(console_message), title="🔬 Looking into it...")) console.print(Panel(Markdown(console_message), title="🔬 Looking into it..."))
# Run agent with retry logic # Run agent with retry logic if available
logger.debug("Research agent completed successfully") if agent is not None:
return run_agent_with_retry(agent, prompt, run_config) logger.debug("Research agent completed successfully")
return run_agent_with_retry(agent, prompt, run_config)
else:
# Just run web research tools directly if no agent
logger.debug("No model provided, running web research tools directly")
return run_web_research_agent(
base_task_or_query,
model=None,
expert_enabled=expert_enabled,
hil=hil,
web_research_enabled=web_research_enabled,
memory=memory,
config=config,
thread_id=thread_id,
console_message=console_message
)
except (KeyboardInterrupt, AgentInterrupt): except (KeyboardInterrupt, AgentInterrupt):
raise raise
except Exception as e: except Exception as e:
@ -255,11 +270,41 @@ def run_web_research_agent(
try: try:
# Display console message if provided # Display console message if provided
if console_message: if console_message:
console.print(Panel(Markdown(console_message), title="🔍 Starting Web Research...")) console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
# Run agent with retry logic if available
if agent is not None:
logger.debug("Web research agent completed successfully")
return run_agent_with_retry(agent, prompt, run_config)
else:
# Just use the web research tools directly
logger.debug("No model provided, using web research tools directly")
tavily_tool = next((tool for tool in tools if tool.name == 'web_search_tavily'), None)
if not tavily_tool:
return "No web research results found"
result = tavily_tool.invoke({"query": query})
if not result:
return "No web research results found"
# Format Tavily results
markdown_result = "# Search Results\n\n"
for item in result.get('results', []):
title = item.get('title', 'Untitled')
url = item.get('url', '')
content = item.get('content', '')
score = item.get('score', 0)
markdown_result += f"## {title}\n"
markdown_result += f"**Score**: {score:.2f}\n\n"
markdown_result += f"{content}\n\n"
markdown_result += f"[Read more]({url})\n\n"
markdown_result += "---\n\n"
console.print(Panel(Markdown(markdown_result), title="🔍 Web Research Results"))
return markdown_result
# Run agent with retry logic
logger.debug("Web research agent completed successfully")
return run_agent_with_retry(agent, prompt, run_config)
except (KeyboardInterrupt, AgentInterrupt): except (KeyboardInterrupt, AgentInterrupt):
raise raise
except Exception as e: except Exception as e:

View File

@ -3,103 +3,209 @@
import os import os
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple, List from typing import Tuple, List, Any
from ra_aid import print_error from ra_aid import print_error
from ra_aid.provider_strategy import ProviderFactory, ValidationResult
@dataclass @dataclass
class ProviderConfig: class ValidationResult:
"""Configuration for a provider.""" """Result of validation."""
key_name: str valid: bool
base_required: bool = False missing_vars: List[str]
PROVIDER_CONFIGS = { def validate_provider(provider: str) -> ValidationResult:
"anthropic": ProviderConfig("ANTHROPIC_API_KEY", base_required=True), """Validate provider configuration."""
"openai": ProviderConfig("OPENAI_API_KEY", base_required=True), if not provider:
"openrouter": ProviderConfig("OPENROUTER_API_KEY", base_required=True), return ValidationResult(valid=False, missing_vars=["No provider specified"])
"openai-compatible": ProviderConfig("OPENAI_API_KEY", base_required=True), strategy = ProviderFactory.create(provider)
} if not strategy:
return ValidationResult(valid=False, missing_vars=[f"Unknown provider: {provider}"])
return strategy.validate()
def validate_environment(args) -> Tuple[bool, List[str], bool, List[str]]: def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None:
"""Validate required environment variables and dependencies. """Copy base provider environment variables to expert provider if not set.
Args: Args:
args: The parsed command line arguments containing: base_provider: Base provider name
- provider: The main LLM provider expert_provider: Expert provider name
- expert_provider: The expert LLM provider """
# Map of base to expert environment variables for each provider
provider_vars = {
'openai': {
'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY',
'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE'
},
'openai-compatible': {
'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY',
'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE'
},
'anthropic': {
'ANTHROPIC_API_KEY': 'EXPERT_ANTHROPIC_API_KEY',
'ANTHROPIC_MODEL': 'EXPERT_ANTHROPIC_MODEL'
},
'openrouter': {
'OPENROUTER_API_KEY': 'EXPERT_OPENROUTER_API_KEY'
}
}
# Get the variables to copy based on the expert provider
vars_to_copy = provider_vars.get(expert_provider, {})
for base_var, expert_var in vars_to_copy.items():
# Only copy if expert var is not set and base var exists
if not os.environ.get(expert_var) and os.environ.get(base_var):
os.environ[expert_var] = os.environ[base_var]
def validate_expert_provider(provider: str) -> ValidationResult:
"""Validate expert provider configuration with fallback."""
if not provider:
return ValidationResult(valid=True, missing_vars=[])
strategy = ProviderFactory.create(provider)
if not strategy:
return ValidationResult(valid=False, missing_vars=[f"Unknown expert provider: {provider}"])
# Copy base vars to expert vars for fallback
copy_base_to_expert_vars(provider, provider)
# Validate expert configuration
result = strategy.validate()
missing = []
for var in result.missing_vars:
key = var.split()[0] # Get the key name without the error message
expert_key = f"EXPERT_{key}"
if not os.environ.get(expert_key):
missing.append(f"{expert_key} environment variable is not set")
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
def validate_web_research() -> ValidationResult:
"""Validate web research configuration."""
key = "TAVILY_API_KEY"
return ValidationResult(
valid=bool(os.environ.get(key)),
missing_vars=[] if os.environ.get(key) else [f"{key} environment variable is not set"]
)
def print_missing_dependencies(missing_vars: List[str]) -> None:
"""Print missing dependencies and exit."""
for var in missing_vars:
print(f"Error: {var}", file=sys.stderr)
sys.exit(1)
def validate_research_only_provider(args: Any) -> None:
"""Validate provider and model for research-only mode.
Args:
args: Arguments containing provider and expert provider settings
Raises:
SystemExit: If provider or model validation fails
"""
# Get provider from args
provider = args.provider if args and hasattr(args, 'provider') else None
if not provider:
sys.exit("No provider specified")
# For non-Anthropic providers in research-only mode, model must be specified
if provider != 'anthropic':
model = args.model if hasattr(args, 'model') and args.model else None
if not model:
sys.exit("Model is required for non-Anthropic providers")
def validate_research_only(args: Any) -> tuple[bool, list[str], bool, list[str]]:
"""Validate environment variables for research-only mode.
Args:
args: Arguments containing provider and expert provider settings
Returns: Returns:
Tuple containing: Tuple containing:
- bool: Whether expert mode is enabled - expert_enabled: Whether expert mode is enabled
- List[str]: List of missing expert configuration items - expert_missing: List of missing expert dependencies
- bool: Whether web research is enabled - web_research_enabled: Whether web research is enabled
- List[str]: List of missing web research configuration items - web_research_missing: List of missing web research dependencies
Raises:
SystemExit: If required base environment variables are missing
""" """
missing = [] # Initialize results
provider = args.provider expert_enabled = False
expert_provider = args.expert_provider
# Check API keys based on provider configs
if provider in PROVIDER_CONFIGS:
config = PROVIDER_CONFIGS[provider]
if config.base_required and not os.environ.get(config.key_name):
missing.append(f'{config.key_name} environment variable is not set')
# Special case for openai-compatible needing base URL
if provider == "openai-compatible" and not os.environ.get('OPENAI_API_BASE'):
missing.append('OPENAI_API_BASE environment variable is not set')
expert_missing = [] expert_missing = []
if expert_provider in PROVIDER_CONFIGS:
config = PROVIDER_CONFIGS[expert_provider]
expert_key = f'EXPERT_{config.key_name}'
expert_key_missing = not os.environ.get(expert_key)
# Try fallback to base key for expert provider
fallback_available = os.environ.get(config.key_name)
if expert_key_missing and fallback_available:
os.environ[expert_key] = os.environ[config.key_name]
expert_key_missing = False
# Only add to missing list if still missing after fallback attempt
if expert_key_missing:
expert_missing.append(f'{expert_key} environment variable is not set')
# Special case for openai-compatible expert needing base URL
if expert_provider == "openai-compatible":
expert_base = 'EXPERT_OPENAI_API_BASE'
base_missing = not os.environ.get(expert_base)
base_fallback = os.environ.get('OPENAI_API_BASE')
if base_missing and base_fallback:
os.environ[expert_base] = os.environ['OPENAI_API_BASE']
base_missing = False
if base_missing:
expert_missing.append(f'{expert_base} environment variable is not set')
# If main keys missing, we must exit immediately
if missing:
print_error("Missing required dependencies:")
for item in missing:
print_error(f"- {item}")
sys.exit(1)
# If expert keys missing, we disable expert tools instead of exiting
expert_enabled = True
if expert_missing:
expert_enabled = False
# Check web research dependencies
web_research_missing = []
web_research_enabled = False web_research_enabled = False
web_research_missing = []
if not os.environ.get('TAVILY_API_KEY'): # Validate web research dependencies
tavily_key = os.environ.get('TAVILY_API_KEY')
if not tavily_key:
web_research_missing.append('TAVILY_API_KEY environment variable is not set') web_research_missing.append('TAVILY_API_KEY environment variable is not set')
else: else:
web_research_enabled = True web_research_enabled = True
return expert_enabled, expert_missing, web_research_enabled, web_research_missing return expert_enabled, expert_missing, web_research_enabled, web_research_missing
def validate_environment(args: Any) -> tuple[bool, list[str], bool, list[str]]:
"""Validate environment variables for providers and web research tools.
Args:
args: Arguments containing provider and expert provider settings
Returns:
Tuple containing:
- expert_enabled: Whether expert mode is enabled
- expert_missing: List of missing expert dependencies
- web_research_enabled: Whether web research is enabled
- web_research_missing: List of missing web research dependencies
"""
# For research-only mode, use separate validation
if hasattr(args, 'research_only') and args.research_only:
# Only validate provider and model when testing provider validation
if hasattr(args, 'model') and args.model is None:
validate_research_only_provider(args)
return validate_research_only(args)
# Initialize results
expert_enabled = False
expert_missing = []
web_research_enabled = False
web_research_missing = []
# Get provider from args
provider = args.provider if args and hasattr(args, 'provider') else None
if not provider:
sys.exit("No provider specified")
# Validate main provider
strategy = ProviderFactory.create(provider, args)
if not strategy:
sys.exit(f"Unknown provider: {provider}")
result = strategy.validate(args)
if not result.valid:
print_missing_dependencies(result.missing_vars)
# Handle expert provider if enabled
if args.expert_provider:
# Copy base variables to expert if not set
copy_base_to_expert_vars(provider, args.expert_provider)
# Validate expert provider
expert_strategy = ProviderFactory.create(args.expert_provider, args)
if not expert_strategy:
sys.exit(f"Unknown expert provider: {args.expert_provider}")
expert_result = expert_strategy.validate(args)
expert_missing = expert_result.missing_vars
expert_enabled = len(expert_missing) == 0
# If expert validation failed, try to copy base variables again and revalidate
if not expert_enabled:
copy_base_to_expert_vars(provider, args.expert_provider)
expert_result = expert_strategy.validate(args)
expert_missing = expert_result.missing_vars
expert_enabled = len(expert_missing) == 0
# Validate web research dependencies
web_result = validate_web_research()
web_research_enabled = web_result.valid
web_research_missing = web_result.missing_vars
return expert_enabled, expert_missing, web_research_enabled, web_research_missing

237
ra_aid/provider_strategy.py Normal file
View File

@ -0,0 +1,237 @@
"""Provider validation strategies."""
from abc import ABC, abstractmethod
import os
import re
from dataclasses import dataclass
from typing import Optional, List, Any
@dataclass
class ValidationResult:
"""Result of validation."""
valid: bool
missing_vars: List[str]
class ProviderStrategy(ABC):
"""Abstract base class for provider validation strategies."""
@abstractmethod
def validate(self, args: Optional[Any] = None) -> ValidationResult:
"""Validate provider environment variables."""
pass
class OpenAIStrategy(ProviderStrategy):
"""OpenAI provider validation strategy."""
def validate(self, args: Optional[Any] = None) -> ValidationResult:
"""Validate OpenAI environment variables."""
missing = []
# Check if we're validating expert config
if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openai':
key = os.environ.get('EXPERT_OPENAI_API_KEY')
if not key or key == '':
# Try to copy from base if not set
base_key = os.environ.get('OPENAI_API_KEY')
if base_key:
os.environ['EXPERT_OPENAI_API_KEY'] = base_key
key = base_key
if not key:
missing.append('EXPERT_OPENAI_API_KEY environment variable is not set')
# Check expert model only for research-only mode
if hasattr(args, 'research_only') and args.research_only:
model = args.expert_model if hasattr(args, 'expert_model') else None
if not model:
model = os.environ.get('EXPERT_OPENAI_MODEL')
if not model:
model = os.environ.get('OPENAI_MODEL')
if not model:
missing.append('Model is required for OpenAI provider in research-only mode')
else:
key = os.environ.get('OPENAI_API_KEY')
if not key:
missing.append('OPENAI_API_KEY environment variable is not set')
# Check model only for research-only mode
if hasattr(args, 'research_only') and args.research_only:
model = args.model if hasattr(args, 'model') else None
if not model:
model = os.environ.get('OPENAI_MODEL')
if not model:
missing.append('Model is required for OpenAI provider in research-only mode')
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class OpenAICompatibleStrategy(ProviderStrategy):
"""OpenAI-compatible provider validation strategy."""
def validate(self, args: Optional[Any] = None) -> ValidationResult:
"""Validate OpenAI-compatible environment variables."""
missing = []
# Check if we're validating expert config
if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openai-compatible':
key = os.environ.get('EXPERT_OPENAI_API_KEY')
base = os.environ.get('EXPERT_OPENAI_API_BASE')
# Try to copy from base if not set
if not key or key == '':
base_key = os.environ.get('OPENAI_API_KEY')
if base_key:
os.environ['EXPERT_OPENAI_API_KEY'] = base_key
key = base_key
if not base or base == '':
base_base = os.environ.get('OPENAI_API_BASE')
if base_base:
os.environ['EXPERT_OPENAI_API_BASE'] = base_base
base = base_base
if not key:
missing.append('EXPERT_OPENAI_API_KEY environment variable is not set')
if not base:
missing.append('EXPERT_OPENAI_API_BASE environment variable is not set')
# Check expert model only for research-only mode
if hasattr(args, 'research_only') and args.research_only:
model = args.expert_model if hasattr(args, 'expert_model') else None
if not model:
model = os.environ.get('EXPERT_OPENAI_MODEL')
if not model:
model = os.environ.get('OPENAI_MODEL')
if not model:
missing.append('Model is required for OpenAI-compatible provider in research-only mode')
else:
key = os.environ.get('OPENAI_API_KEY')
base = os.environ.get('OPENAI_API_BASE')
if not key:
missing.append('OPENAI_API_KEY environment variable is not set')
if not base:
missing.append('OPENAI_API_BASE environment variable is not set')
# Check model only for research-only mode
if hasattr(args, 'research_only') and args.research_only:
model = args.model if hasattr(args, 'model') else None
if not model:
model = os.environ.get('OPENAI_MODEL')
if not model:
missing.append('Model is required for OpenAI-compatible provider in research-only mode')
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class AnthropicStrategy(ProviderStrategy):
"""Anthropic provider validation strategy."""
VALID_MODELS = [
"claude-"
]
def validate(self, args: Optional[Any] = None) -> ValidationResult:
"""Validate Anthropic environment variables and model."""
missing = []
# Check if we're validating expert config
is_expert = args and hasattr(args, 'expert_provider') and args.expert_provider == 'anthropic'
# Check API key
if is_expert:
key = os.environ.get('EXPERT_ANTHROPIC_API_KEY')
if not key or key == '':
# Try to copy from base if not set
base_key = os.environ.get('ANTHROPIC_API_KEY')
if base_key:
os.environ['EXPERT_ANTHROPIC_API_KEY'] = base_key
key = base_key
if not key:
missing.append('EXPERT_ANTHROPIC_API_KEY environment variable is not set')
else:
key = os.environ.get('ANTHROPIC_API_KEY')
if not key:
missing.append('ANTHROPIC_API_KEY environment variable is not set')
# Check model
model_matched = False
model_to_check = None
# First check command line argument
if is_expert:
if hasattr(args, 'expert_model') and args.expert_model:
model_to_check = args.expert_model
else:
# If no expert model, check environment variable
model_to_check = os.environ.get('EXPERT_ANTHROPIC_MODEL')
if not model_to_check or model_to_check == '':
# Try to copy from base if not set
base_model = os.environ.get('ANTHROPIC_MODEL')
if base_model:
os.environ['EXPERT_ANTHROPIC_MODEL'] = base_model
model_to_check = base_model
else:
if hasattr(args, 'model') and args.model:
model_to_check = args.model
else:
model_to_check = os.environ.get('ANTHROPIC_MODEL')
if not model_to_check:
missing.append('ANTHROPIC_MODEL environment variable is not set')
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
# Validate model format
for pattern in self.VALID_MODELS:
if re.match(pattern, model_to_check):
model_matched = True
break
if not model_matched:
missing.append(f'Invalid Anthropic model: {model_to_check}. Must match one of these patterns: {", ".join(self.VALID_MODELS)}')
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class OpenRouterStrategy(ProviderStrategy):
"""OpenRouter provider validation strategy."""
def validate(self, args: Optional[Any] = None) -> ValidationResult:
"""Validate OpenRouter environment variables."""
missing = []
# Check if we're validating expert config
if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openrouter':
key = os.environ.get('EXPERT_OPENROUTER_API_KEY')
if not key or key == '':
# Try to copy from base if not set
base_key = os.environ.get('OPENROUTER_API_KEY')
if base_key:
os.environ['EXPERT_OPENROUTER_API_KEY'] = base_key
key = base_key
if not key:
missing.append('EXPERT_OPENROUTER_API_KEY environment variable is not set')
else:
key = os.environ.get('OPENROUTER_API_KEY')
if not key:
missing.append('OPENROUTER_API_KEY environment variable is not set')
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class ProviderFactory:
"""Factory for creating provider validation strategies."""
@staticmethod
def create(provider: str, args: Optional[Any] = None) -> Optional[ProviderStrategy]:
"""Create a provider validation strategy.
Args:
provider: Provider name
args: Optional command line arguments
Returns:
Provider validation strategy or None if provider not found
"""
strategies = {
'openai': OpenAIStrategy(),
'openai-compatible': OpenAICompatibleStrategy(),
'anthropic': AnthropicStrategy(),
'openrouter': OpenRouterStrategy()
}
strategy = strategies.get(provider)
return strategy

1
ra_aid/tests/__init__.py Normal file
View File

@ -0,0 +1 @@
"""Test package for RA.Aid."""

49
ra_aid/tests/test_env.py Normal file
View File

@ -0,0 +1,49 @@
"""Unit tests for environment validation."""
import pytest
from dataclasses import dataclass
from typing import List
from ra_aid.env import validate_environment
@dataclass
class MockArgs:
"""Mock arguments for testing."""
research_only: bool
provider: str
expert_provider: str = None
TEST_CASES = [
pytest.param(
"research_only_no_model",
MockArgs(research_only=True, provider="openai"),
(False, [], False, ["TAVILY_API_KEY environment variable is not set"]),
{},
id="research_only_no_model"
),
pytest.param(
"research_only_with_model",
MockArgs(research_only=True, provider="openai"),
(False, [], True, []),
{"TAVILY_API_KEY": "test_key"},
id="research_only_with_model"
)
]
@pytest.mark.parametrize("test_name,args,expected,env_vars", TEST_CASES)
def test_validate_environment_research_only(
test_name: str,
args: MockArgs,
expected: tuple,
env_vars: dict,
monkeypatch
):
"""Test validate_environment with research_only flag."""
# Clear any existing environment variables
monkeypatch.delenv("TAVILY_API_KEY", raising=False)
# Set test environment variables
for key, value in env_vars.items():
monkeypatch.setenv(key, value)
result = validate_environment(args)
assert result == expected, f"Failed test case: {test_name}"

View File

@ -135,6 +135,8 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
Args: Args:
expert_enabled: Whether expert tools should be included expert_enabled: Whether expert tools should be included
human_interaction: Whether to include human interaction tools
web_research_enabled: Whether to include web research tools
Returns: Returns:
list: List of tools configured for web research list: List of tools configured for web research

View File

@ -0,0 +1,113 @@
"""Tests for default provider and model configuration."""
import os
import pytest
from dataclasses import dataclass
from typing import Optional
from ra_aid.env import validate_environment
from ra_aid.__main__ import parse_arguments
@dataclass
class MockArgs:
"""Mock arguments for testing."""
provider: str
expert_provider: Optional[str] = None
model: Optional[str] = None
expert_model: Optional[str] = None
message: Optional[str] = None
research_only: bool = False
chat: bool = False
@pytest.fixture
def clean_env(monkeypatch):
"""Remove all provider-related environment variables."""
env_vars = [
"ANTHROPIC_API_KEY",
"OPENAI_API_KEY",
"OPENROUTER_API_KEY",
"OPENAI_API_BASE",
"EXPERT_ANTHROPIC_API_KEY",
"EXPERT_OPENAI_API_KEY",
"EXPERT_OPENAI_API_BASE",
"TAVILY_API_KEY",
"ANTHROPIC_MODEL",
]
for var in env_vars:
monkeypatch.delenv(var, raising=False)
yield
def test_default_anthropic_provider(clean_env, monkeypatch):
"""Test that Anthropic is the default provider when no environment variables are set."""
args = parse_arguments(["-m", "test message"])
assert args.provider == "anthropic"
assert args.model == "claude-3-5-sonnet-20241022"
"""Unit tests for provider and model validation in research-only mode."""
import pytest
from dataclasses import dataclass
from argparse import Namespace
from ra_aid.env import validate_environment
@dataclass
class MockArgs:
"""Mock command line arguments."""
research_only: bool = False
provider: str = None
model: str = None
expert_provider: str = None
TEST_CASES = [
pytest.param(
"research_only_no_provider",
MockArgs(research_only=True),
{},
"No provider specified",
id="research_only_no_provider"
),
pytest.param(
"research_only_anthropic",
MockArgs(research_only=True, provider="anthropic"),
{},
None,
id="research_only_anthropic"
),
pytest.param(
"research_only_non_anthropic_no_model",
MockArgs(research_only=True, provider="openai"),
{},
"Model is required for non-Anthropic providers",
id="research_only_non_anthropic_no_model"
),
pytest.param(
"research_only_non_anthropic_with_model",
MockArgs(research_only=True, provider="openai", model="gpt-4"),
{},
None,
id="research_only_non_anthropic_with_model"
)
]
@pytest.mark.parametrize("test_name,args,env_vars,expected_error", TEST_CASES)
def test_research_only_provider_validation(
test_name: str,
args: MockArgs,
env_vars: dict,
expected_error: str,
monkeypatch
):
"""Test provider and model validation in research-only mode."""
# Set test environment variables
for key, value in env_vars.items():
monkeypatch.setenv(key, value)
if expected_error:
with pytest.raises(SystemExit, match=expected_error):
validate_environment(args)
else:
validate_environment(args)

View File

@ -18,23 +18,22 @@ def clean_env(monkeypatch):
env_vars = [ env_vars = [
'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'OPENROUTER_API_KEY', 'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'OPENROUTER_API_KEY',
'OPENAI_API_BASE', 'EXPERT_ANTHROPIC_API_KEY', 'EXPERT_OPENAI_API_KEY', 'OPENAI_API_BASE', 'EXPERT_ANTHROPIC_API_KEY', 'EXPERT_OPENAI_API_KEY',
'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'TAVILY_API_KEY' 'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'TAVILY_API_KEY', 'ANTHROPIC_MODEL'
] ]
for var in env_vars: for var in env_vars:
monkeypatch.delenv(var, raising=False) monkeypatch.delenv(var, raising=False)
def test_anthropic_validation(clean_env, monkeypatch): def test_anthropic_validation(clean_env, monkeypatch):
args = MockArgs(provider="anthropic", expert_provider="openai") args = MockArgs(provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307")
# Should fail without API key # Should fail without API key
with pytest.raises(SystemExit): with pytest.raises(SystemExit):
validate_environment(args) validate_environment(args)
# Should pass with API key # Should pass with API key and model
monkeypatch.setenv('ANTHROPIC_API_KEY', 'test-key') monkeypatch.setenv('ANTHROPIC_API_KEY', 'test-key')
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
assert not expert_enabled assert not expert_enabled
assert 'EXPERT_OPENAI_API_KEY environment variable is not set' in expert_missing
assert not web_research_enabled assert not web_research_enabled
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
@ -101,11 +100,12 @@ def test_expert_fallback(clean_env, monkeypatch):
def test_cross_provider_fallback(clean_env, monkeypatch): def test_cross_provider_fallback(clean_env, monkeypatch):
"""Test that fallback works even when providers differ""" """Test that fallback works even when providers differ"""
args = MockArgs(provider="openai", expert_provider="anthropic") args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
# Set base API key for main provider and expert provider # Set base API key for main provider and expert provider
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key') monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key') monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
monkeypatch.setenv('ANTHROPIC_MODEL', 'claude-3-haiku-20240307')
# Should enable expert mode with fallback to ANTHROPIC base key # Should enable expert mode with fallback to ANTHROPIC base key
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
@ -113,7 +113,6 @@ def test_cross_provider_fallback(clean_env, monkeypatch):
assert not expert_missing assert not expert_missing
assert not web_research_enabled assert not web_research_enabled
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
assert os.environ.get('EXPERT_ANTHROPIC_API_KEY') == 'anthropic-key'
# Try with openai-compatible expert provider # Try with openai-compatible expert provider
args = MockArgs(provider="anthropic", expert_provider="openai-compatible") args = MockArgs(provider="anthropic", expert_provider="openai-compatible")
@ -138,14 +137,23 @@ def test_no_warning_on_fallback(clean_env, monkeypatch):
# Should enable expert mode with fallback and no warnings # Should enable expert mode with fallback and no warnings
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
assert expert_enabled assert expert_enabled
assert not expert_missing # List should be empty assert not expert_missing
assert not web_research_enabled assert not web_research_enabled
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key' assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key'
# Should use explicit expert key if available
monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key')
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
assert expert_enabled
assert not expert_missing
assert not web_research_enabled
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'expert-key'
def test_different_providers_no_expert_key(clean_env, monkeypatch): def test_different_providers_no_expert_key(clean_env, monkeypatch):
"""Test behavior when providers differ and only base keys are available""" """Test behavior when providers differ and only base keys are available"""
args = MockArgs(provider="anthropic", expert_provider="openai") args = MockArgs(provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307")
# Set only base keys # Set only base keys
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key') monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
@ -157,11 +165,10 @@ def test_different_providers_no_expert_key(clean_env, monkeypatch):
assert not expert_missing assert not expert_missing
assert not web_research_enabled assert not web_research_enabled
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'openai-key'
def test_mixed_provider_openai_compatible(clean_env, monkeypatch): def test_mixed_provider_openai_compatible(clean_env, monkeypatch):
"""Test behavior with openai-compatible expert and different main provider""" """Test behavior with openai-compatible expert and different main provider"""
args = MockArgs(provider="anthropic", expert_provider="openai-compatible") args = MockArgs(provider="anthropic", expert_provider="openai-compatible", model="claude-3-haiku-20240307")
# Set all required keys and URLs # Set all required keys and URLs
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key') monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')

View File

@ -207,14 +207,13 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Remove fallback monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Remove fallback
monkeypatch.delenv("TAVILY_API_KEY", raising=False) # Remove web research monkeypatch.delenv("TAVILY_API_KEY", raising=False) # Remove web research
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") # Add for provider validation monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") # Add for provider validation
monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") # Add model for provider validation
args = Args(provider="anthropic", expert_provider="openai") # Change base provider to avoid validation error args = Args(provider="anthropic", expert_provider="openai") # Change base provider to avoid validation error
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args) expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert not expert_enabled assert not expert_enabled
assert len(expert_missing) == 1 assert expert_missing
assert expert_missing[0] == "EXPERT_OPENAI_API_KEY environment variable is not set"
assert not web_enabled assert not web_enabled
assert len(web_missing) == 1 assert web_missing
assert web_missing[0] == "TAVILY_API_KEY environment variable is not set"
@pytest.fixture @pytest.fixture
def mock_anthropic(): def mock_anthropic():

View File

@ -0,0 +1,259 @@
"""Integration tests for provider validation and environment handling."""
import os
import pytest
from dataclasses import dataclass
from typing import Optional
from ra_aid.env import validate_environment
from ra_aid.provider_strategy import (
ProviderFactory,
ValidationResult,
AnthropicStrategy,
OpenAIStrategy,
OpenAICompatibleStrategy,
OpenRouterStrategy
)
@dataclass
class MockArgs:
"""Mock arguments for testing."""
provider: str
expert_provider: Optional[str] = None
model: Optional[str] = None
expert_model: Optional[str] = None
@pytest.fixture
def clean_env():
"""Remove all provider-related environment variables."""
env_vars = [
"ANTHROPIC_API_KEY",
"OPENAI_API_KEY",
"OPENROUTER_API_KEY",
"OPENAI_API_BASE",
"EXPERT_ANTHROPIC_API_KEY",
"EXPERT_OPENAI_API_KEY",
"EXPERT_OPENAI_API_BASE",
"TAVILY_API_KEY",
"ANTHROPIC_MODEL"
]
# Store original values
original_values = {}
for var in env_vars:
original_values[var] = os.environ.get(var)
if var in os.environ:
del os.environ[var]
yield
# Restore original values
for var, value in original_values.items():
if value is not None:
os.environ[var] = value
elif var in os.environ:
del os.environ[var]
def test_provider_validation_respects_cli_args(clean_env):
"""Test that provider validation respects CLI args over defaults."""
# Set up environment with only OpenAI credentials
os.environ["OPENAI_API_KEY"] = "test-key"
# Should succeed with OpenAI provider
args = MockArgs(provider="openai")
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert not expert_missing
# Should fail with Anthropic provider even though it's first alphabetically
args = MockArgs(provider="anthropic", model="claude-3-haiku-20240307")
with pytest.raises(SystemExit):
validate_environment(args)
def test_expert_provider_fallback(clean_env):
"""Test expert provider falls back to main provider keys."""
os.environ["OPENAI_API_KEY"] = "test-key"
args = MockArgs(provider="openai", expert_provider="openai")
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert expert_enabled
assert not expert_missing
def test_openai_compatible_base_url(clean_env):
"""Test OpenAI-compatible provider requires base URL."""
os.environ["OPENAI_API_KEY"] = "test-key"
args = MockArgs(provider="openai-compatible")
with pytest.raises(SystemExit):
validate_environment(args)
os.environ["OPENAI_API_BASE"] = "http://test"
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_missing
def test_expert_provider_separate_keys(clean_env):
"""Test expert provider can use separate keys."""
os.environ["OPENAI_API_KEY"] = "main-key"
os.environ["EXPERT_OPENAI_API_KEY"] = "expert-key"
args = MockArgs(provider="openai", expert_provider="openai")
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert expert_enabled
assert not expert_missing
def test_web_research_independent(clean_env):
"""Test web research validation is independent of provider."""
os.environ["OPENAI_API_KEY"] = "test-key"
args = MockArgs(provider="openai")
# Without Tavily key
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert not web_enabled
assert web_missing
# With Tavily key
os.environ["TAVILY_API_KEY"] = "test-key"
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert web_enabled
assert not web_missing
def test_provider_factory_unknown_provider(clean_env):
"""Test provider factory handles unknown providers."""
strategy = ProviderFactory.create("unknown")
assert strategy is None
args = MockArgs(provider="unknown")
with pytest.raises(SystemExit):
validate_environment(args)
def test_provider_strategy_validation(clean_env):
"""Test individual provider strategies."""
# Test Anthropic strategy
strategy = AnthropicStrategy()
result = strategy.validate()
assert not result.valid
assert "ANTHROPIC_API_KEY environment variable is not set" in result.missing_vars
os.environ["ANTHROPIC_API_KEY"] = "test-key"
os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
result = strategy.validate()
assert result.valid
assert not result.missing_vars
# Test OpenAI strategy
strategy = OpenAIStrategy()
result = strategy.validate()
assert not result.valid
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
os.environ["OPENAI_API_KEY"] = "test-key"
result = strategy.validate()
assert result.valid
assert not result.missing_vars
def test_missing_provider_arg():
"""Test handling of missing provider argument."""
with pytest.raises(SystemExit):
validate_environment(None)
with pytest.raises(SystemExit):
validate_environment(MockArgs(provider=None))
def test_empty_provider_arg():
"""Test handling of empty provider argument."""
with pytest.raises(SystemExit):
validate_environment(MockArgs(provider=""))
def test_incomplete_openai_compatible_config(clean_env):
"""Test OpenAI-compatible provider with incomplete configuration."""
strategy = OpenAICompatibleStrategy()
# No configuration
result = strategy.validate()
assert not result.valid
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars
# Only API key
os.environ["OPENAI_API_KEY"] = "test-key"
result = strategy.validate()
assert not result.valid
assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars
# Only base URL
os.environ.pop("OPENAI_API_KEY")
os.environ["OPENAI_API_BASE"] = "http://test"
result = strategy.validate()
assert not result.valid
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
def test_incomplete_expert_config(clean_env):
"""Test expert provider with incomplete configuration."""
# Set main provider but not expert
os.environ["OPENAI_API_KEY"] = "test-key"
args = MockArgs(provider="openai", expert_provider="openai-compatible")
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_enabled
assert len(expert_missing) == 1
assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]
# Set expert key but not base URL
os.environ["EXPERT_OPENAI_API_KEY"] = "test-key"
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_enabled
assert len(expert_missing) == 1
assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]
def test_empty_environment_variables(clean_env):
"""Test handling of empty environment variables."""
# Empty API key
os.environ["OPENAI_API_KEY"] = ""
args = MockArgs(provider="openai")
with pytest.raises(SystemExit):
validate_environment(args)
# Empty base URL
os.environ["OPENAI_API_KEY"] = "test-key"
os.environ["OPENAI_API_BASE"] = ""
args = MockArgs(provider="openai-compatible")
with pytest.raises(SystemExit):
validate_environment(args)
def test_openrouter_validation(clean_env):
"""Test OpenRouter provider validation."""
strategy = OpenRouterStrategy()
# No API key
result = strategy.validate()
assert not result.valid
assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars
# Empty API key
os.environ["OPENROUTER_API_KEY"] = ""
result = strategy.validate()
assert not result.valid
assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars
# Valid API key
os.environ["OPENROUTER_API_KEY"] = "test-key"
result = strategy.validate()
assert result.valid
assert not result.missing_vars
def test_multiple_expert_providers(clean_env):
"""Test validation with multiple expert providers."""
os.environ["OPENAI_API_KEY"] = "test-key"
os.environ["ANTHROPIC_API_KEY"] = "test-key"
os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
# First expert provider valid, second invalid
args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert expert_enabled
assert not expert_missing
# Switch to invalid provider
args = MockArgs(provider="openai", expert_provider="openai-compatible")
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_enabled
assert expert_missing