FEAT fix command line args and env var dependencies on anthropic (#21)

* FIX provider cmdline args

* FIX Issue 18

* FIX ensure research-only requires a model
Jose M Leon 2024-12-28 16:53:57 -05:00 committed by GitHub
parent 9944ec9ea4
commit 8b3f4d736c
15 changed files with 1027 additions and 185 deletions


@@ -5,13 +5,12 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [0.10.3] - 1024-12-27
+## [0.10.3] - 2024-12-27
 - Fix logging on interrupt.
 - Fix web research prompt.
 - Simplify planning stage by executing tasks directly.
 - Make research notes available to more agents/tools.
-- Make read_file always output status panel.
 
 ## [0.10.2] - 2024-12-26

Makefile (new file, +10)

@@ -0,0 +1,10 @@
.PHONY: test setup-dev setup-hooks

test:
	python -m pytest

setup-dev:
	pip install -e ".[dev]"

setup-hooks: setup-dev
	pre-commit install


@@ -13,11 +13,11 @@ keywords = ["langchain", "ai", "agent", "tools", "development"]
 authors = [{name = "AI Christianson", email = "ai.christianson@christianson.ai"}]
 classifiers = [
     "Development Status :: 4 - Beta",
     "Intended Audience :: Developers",
     "License :: OSI Approved :: Apache Software License",
     "Programming Language :: Python :: 3",
     "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Topic :: Software Development :: Libraries :: Python Modules"
 ]
@@ -31,7 +31,7 @@ dependencies = [
     "langchain>=0.3.13",
     "rich>=13.0.0",
     "GitPython>=3.1",
     "fuzzywuzzy==0.18.0",
     "python-Levenshtein==0.23.0",
     "pathspec>=0.11.0",
     "aider-chat>=0.69.1",


@@ -26,10 +26,14 @@ from ra_aid.logging_config import setup_logging, get_logger
 from ra_aid.tool_configs import (
     get_chat_tools
 )
+import os
 
 logger = get_logger(__name__)
 
-def parse_arguments():
+def parse_arguments(args=None):
+    VALID_PROVIDERS = ['anthropic', 'openai', 'openrouter', 'openai-compatible']
+    ANTHROPIC_DEFAULT_MODEL = 'claude-3-5-sonnet-20241022'
+
     parser = argparse.ArgumentParser(
         description='RA.Aid - AI Agent for executing programming and research tasks',
         formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -59,7 +63,7 @@ Examples:
         '--provider',
         type=str,
         default='anthropic',
-        choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'],
+        choices=VALID_PROVIDERS,
         help='The LLM provider to use'
     )
     parser.add_argument(
@@ -76,7 +80,7 @@ Examples:
         '--expert-provider',
         type=str,
         default='openai',
-        choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'],
+        choices=VALID_PROVIDERS,
         help='The LLM provider to use for expert knowledge queries (default: openai)'
     )
     parser.add_argument(
@@ -99,25 +103,32 @@ Examples:
         action='store_true',
         help='Enable verbose logging output'
     )
 
-    args = parser.parse_args()
+    if args is None:
+        args = sys.argv[1:]
+    parsed_args = parser.parse_args(args)
 
     # Set hil=True when chat mode is enabled
-    if args.chat:
-        args.hil = True
+    if parsed_args.chat:
+        parsed_args.hil = True
 
-    # Set default model for Anthropic, require model for other providers
-    if args.provider == 'anthropic':
-        if not args.model:
-            args.model = 'claude-3-5-sonnet-20241022'
-    elif not args.model:
-        parser.error(f"--model is required when using provider '{args.provider}'")
+    # Validate provider
+    if parsed_args.provider not in VALID_PROVIDERS:
+        parser.error(f"Invalid provider: {parsed_args.provider}")
+
+    # Handle model defaults and requirements
+    if parsed_args.provider == 'anthropic':
+        # Always use default model for Anthropic
+        parsed_args.model = ANTHROPIC_DEFAULT_MODEL
+    elif not parsed_args.model and not parsed_args.research_only:
+        # Require model for other providers unless in research mode
+        parser.error(f"--model is required when using provider '{parsed_args.provider}'")
 
     # Validate expert model requirement
-    if args.expert_provider != 'openai' and not args.expert_model:
-        parser.error(f"--expert-model is required when using expert provider '{args.expert_provider}'")
+    if parsed_args.expert_provider != 'openai' and not parsed_args.expert_model and not parsed_args.research_only:
+        parser.error(f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'")
 
-    return args
+    return parsed_args
 # Create console instance
 console = Console()
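Editor's note: because parse_arguments() now accepts an explicit argument list (falling back to sys.argv[1:] when none is given), the new defaults can be exercised directly from Python. A minimal sketch, with illustrative message text (behavior per the diff and the tests below):

    from ra_aid.__main__ import parse_arguments

    # Anthropic is the default provider and always gets the default model
    args = parse_arguments(["-m", "hello"])
    assert args.provider == "anthropic"
    assert args.model == "claude-3-5-sonnet-20241022"

    # Other providers must name a model explicitly (outside research-only mode)
    args = parse_arguments(["-m", "hello", "--provider", "openai", "--model", "gpt-4"])
    assert args.model == "gpt-4"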
@@ -143,14 +154,14 @@ def main():
     args = parse_arguments()
     setup_logging(args.verbose)
     logger.debug("Starting RA.Aid with arguments: %s", args)
 
     try:
         expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)  # Will exit if main env vars missing
         logger.debug("Environment validation successful")
 
         if expert_missing:
             console.print(Panel(
                 f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n" +
                 "\n".join(f"- {m}" for m in expert_missing) +
                 "\nSet the required environment variables or args to enable expert mode.",
                 title="Expert Tools Disabled",
@@ -159,20 +170,24 @@ def main():
         if web_research_missing:
             console.print(Panel(
                 f"[yellow]Web research disabled due to missing configuration:[/yellow]\n" +
                 "\n".join(f"- {m}" for m in web_research_missing) +
                 "\nSet the required environment variables to enable web research.",
                 title="Web Research Disabled",
                 style="yellow"
             ))
 
         # Create the base model after validation
         model = initialize_llm(args.provider, args.model)
 
         # Handle chat mode
         if args.chat:
+            if args.research_only:
+                print_error("Chat mode cannot be used with --research-only")
+                sys.exit(1)
+
             print_stage_header("Chat Mode")
 
             # Get initial request from user
             initial_request = ask_human.invoke({"question": "What would you like help with?"})
@@ -182,7 +197,7 @@ def main():
                 get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
                 checkpointer=MemorySaver()
             )
 
             # Run chat agent with CHAT_PROMPT
             config = {
                 "configurable": {"thread_id": uuid.uuid4()},
@@ -193,14 +208,14 @@ def main():
                 "web_research_enabled": web_research_enabled,
                 "initial_request": initial_request
             }
 
             # Store config in global memory
             _global_memory['config'] = config
             _global_memory['config']['provider'] = args.provider
             _global_memory['config']['model'] = args.model
             _global_memory['config']['expert_provider'] = args.expert_provider
             _global_memory['config']['expert_model'] = args.expert_model
 
             # Run chat agent and exit
             run_agent_with_retry(chat_agent, CHAT_PROMPT.format(
                 initial_request=initial_request,
@@ -212,7 +227,7 @@ def main():
         if not args.message:
             print_error("--message is required")
             sys.exit(1)
 
         base_task = args.message
         config = {
             "configurable": {"thread_id": uuid.uuid4()},
@@ -221,21 +236,21 @@ def main():
             "cowboy_mode": args.cowboy_mode,
             "web_research_enabled": web_research_enabled
         }
 
         # Store config in global memory for access by is_informational_query
         _global_memory['config'] = config
 
         # Store model configuration
         _global_memory['config']['provider'] = args.provider
         _global_memory['config']['model'] = args.model
 
         # Store expert provider and model in config
         _global_memory['config']['expert_provider'] = args.expert_provider
         _global_memory['config']['expert_model'] = args.expert_model
 
         # Run research stage
         print_stage_header("Research Stage")
 
         run_research_agent(
             base_task,
             model,
@@ -245,7 +260,7 @@ def main():
             memory=research_memory,
             config=config
         )
 
         # Proceed with planning and implementation if not an informational query
         if not is_informational_query():
             # Run planning agent


@@ -162,9 +162,24 @@ def run_research_agent(
         if console_message:
             console.print(Panel(Markdown(console_message), title="🔬 Looking into it..."))
 
-        # Run agent with retry logic
-        logger.debug("Research agent completed successfully")
-        return run_agent_with_retry(agent, prompt, run_config)
+        # Run agent with retry logic if available
+        if agent is not None:
+            logger.debug("Research agent completed successfully")
+            return run_agent_with_retry(agent, prompt, run_config)
+        else:
+            # Just run web research tools directly if no agent
+            logger.debug("No model provided, running web research tools directly")
+            return run_web_research_agent(
+                base_task_or_query,
+                model=None,
+                expert_enabled=expert_enabled,
+                hil=hil,
+                web_research_enabled=web_research_enabled,
+                memory=memory,
+                config=config,
+                thread_id=thread_id,
+                console_message=console_message
+            )
     except (KeyboardInterrupt, AgentInterrupt):
         raise
     except Exception as e:
@@ -255,11 +270,41 @@ def run_web_research_agent(
     try:
         # Display console message if provided
         if console_message:
-            console.print(Panel(Markdown(console_message), title="🔍 Starting Web Research..."))
+            console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
 
+        # Run agent with retry logic if available
+        if agent is not None:
+            logger.debug("Web research agent completed successfully")
+            return run_agent_with_retry(agent, prompt, run_config)
+        else:
+            # Just use the web research tools directly
+            logger.debug("No model provided, using web research tools directly")
+            tavily_tool = next((tool for tool in tools if tool.name == 'web_search_tavily'), None)
+            if not tavily_tool:
+                return "No web research results found"
+
+            result = tavily_tool.invoke({"query": query})
+            if not result:
+                return "No web research results found"
+
+            # Format Tavily results
+            markdown_result = "# Search Results\n\n"
+            for item in result.get('results', []):
+                title = item.get('title', 'Untitled')
+                url = item.get('url', '')
+                content = item.get('content', '')
+                score = item.get('score', 0)
+
+                markdown_result += f"## {title}\n"
+                markdown_result += f"**Score**: {score:.2f}\n\n"
+                markdown_result += f"{content}\n\n"
+                markdown_result += f"[Read more]({url})\n\n"
+                markdown_result += "---\n\n"
+
+            console.print(Panel(Markdown(markdown_result), title="🔍 Web Research Results"))
+            return markdown_result
-        # Run agent with retry logic
-        logger.debug("Web research agent completed successfully")
-        return run_agent_with_retry(agent, prompt, run_config)
     except (KeyboardInterrupt, AgentInterrupt):
         raise
     except Exception as e:


@@ -3,103 +3,209 @@
 import os
 import sys
 from dataclasses import dataclass
-from typing import Tuple, List
+from typing import Tuple, List, Any
 from ra_aid import print_error
+from ra_aid.provider_strategy import ProviderFactory, ValidationResult
 
 @dataclass
-class ProviderConfig:
-    """Configuration for a provider."""
-    key_name: str
-    base_required: bool = False
-
-PROVIDER_CONFIGS = {
-    "anthropic": ProviderConfig("ANTHROPIC_API_KEY", base_required=True),
-    "openai": ProviderConfig("OPENAI_API_KEY", base_required=True),
-    "openrouter": ProviderConfig("OPENROUTER_API_KEY", base_required=True),
-    "openai-compatible": ProviderConfig("OPENAI_API_KEY", base_required=True),
-}
+class ValidationResult:
+    """Result of validation."""
+    valid: bool
+    missing_vars: List[str]
+
+def validate_provider(provider: str) -> ValidationResult:
+    """Validate provider configuration."""
+    if not provider:
+        return ValidationResult(valid=False, missing_vars=["No provider specified"])
+    strategy = ProviderFactory.create(provider)
+    if not strategy:
+        return ValidationResult(valid=False, missing_vars=[f"Unknown provider: {provider}"])
+    return strategy.validate()
 
+def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None:
+    """Copy base provider environment variables to expert provider if not set.
-def validate_environment(args) -> Tuple[bool, List[str], bool, List[str]]:
-    """Validate required environment variables and dependencies.
 
     Args:
-        args: The parsed command line arguments containing:
-            - provider: The main LLM provider
-            - expert_provider: The expert LLM provider
+        base_provider: Base provider name
+        expert_provider: Expert provider name
+    """
+    # Map of base to expert environment variables for each provider
+    provider_vars = {
+        'openai': {
+            'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY',
+            'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE'
+        },
+        'openai-compatible': {
+            'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY',
+            'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE'
+        },
+        'anthropic': {
+            'ANTHROPIC_API_KEY': 'EXPERT_ANTHROPIC_API_KEY',
+            'ANTHROPIC_MODEL': 'EXPERT_ANTHROPIC_MODEL'
+        },
+        'openrouter': {
+            'OPENROUTER_API_KEY': 'EXPERT_OPENROUTER_API_KEY'
+        }
+    }
+    # Get the variables to copy based on the expert provider
+    vars_to_copy = provider_vars.get(expert_provider, {})
+    for base_var, expert_var in vars_to_copy.items():
+        # Only copy if expert var is not set and base var exists
+        if not os.environ.get(expert_var) and os.environ.get(base_var):
+            os.environ[expert_var] = os.environ[base_var]
+def validate_expert_provider(provider: str) -> ValidationResult:
+    """Validate expert provider configuration with fallback."""
+    if not provider:
+        return ValidationResult(valid=True, missing_vars=[])
+    strategy = ProviderFactory.create(provider)
+    if not strategy:
+        return ValidationResult(valid=False, missing_vars=[f"Unknown expert provider: {provider}"])
+
+    # Copy base vars to expert vars for fallback
+    copy_base_to_expert_vars(provider, provider)
+
+    # Validate expert configuration
+    result = strategy.validate()
+    missing = []
+    for var in result.missing_vars:
+        key = var.split()[0]  # Get the key name without the error message
+        expert_key = f"EXPERT_{key}"
+        if not os.environ.get(expert_key):
+            missing.append(f"{expert_key} environment variable is not set")
+    return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
+
+def validate_web_research() -> ValidationResult:
+    """Validate web research configuration."""
+    key = "TAVILY_API_KEY"
+    return ValidationResult(
+        valid=bool(os.environ.get(key)),
+        missing_vars=[] if os.environ.get(key) else [f"{key} environment variable is not set"]
+    )
+
+def print_missing_dependencies(missing_vars: List[str]) -> None:
+    """Print missing dependencies and exit."""
+    for var in missing_vars:
+        print(f"Error: {var}", file=sys.stderr)
+    sys.exit(1)
+def validate_research_only_provider(args: Any) -> None:
+    """Validate provider and model for research-only mode.
+
+    Args:
+        args: Arguments containing provider and expert provider settings
+
+    Raises:
+        SystemExit: If provider or model validation fails
+    """
+    # Get provider from args
+    provider = args.provider if args and hasattr(args, 'provider') else None
+    if not provider:
+        sys.exit("No provider specified")
+
+    # For non-Anthropic providers in research-only mode, model must be specified
+    if provider != 'anthropic':
+        model = args.model if hasattr(args, 'model') and args.model else None
+        if not model:
+            sys.exit("Model is required for non-Anthropic providers")
+
+def validate_research_only(args: Any) -> tuple[bool, list[str], bool, list[str]]:
+    """Validate environment variables for research-only mode.
+
+    Args:
+        args: Arguments containing provider and expert provider settings
     Returns:
         Tuple containing:
-            - bool: Whether expert mode is enabled
-            - List[str]: List of missing expert configuration items
-            - bool: Whether web research is enabled
-            - List[str]: List of missing web research configuration items
-
-    Raises:
-        SystemExit: If required base environment variables are missing
+            - expert_enabled: Whether expert mode is enabled
+            - expert_missing: List of missing expert dependencies
+            - web_research_enabled: Whether web research is enabled
+            - web_research_missing: List of missing web research dependencies
     """
-    missing = []
-    provider = args.provider
-    expert_provider = args.expert_provider
-
-    # Check API keys based on provider configs
-    if provider in PROVIDER_CONFIGS:
-        config = PROVIDER_CONFIGS[provider]
-        if config.base_required and not os.environ.get(config.key_name):
-            missing.append(f'{config.key_name} environment variable is not set')
-
-    # Special case for openai-compatible needing base URL
-    if provider == "openai-compatible" and not os.environ.get('OPENAI_API_BASE'):
-        missing.append('OPENAI_API_BASE environment variable is not set')
+    # Initialize results
+    expert_enabled = False
     expert_missing = []
-    if expert_provider in PROVIDER_CONFIGS:
-        config = PROVIDER_CONFIGS[expert_provider]
-        expert_key = f'EXPERT_{config.key_name}'
-        expert_key_missing = not os.environ.get(expert_key)
-        # Try fallback to base key for expert provider
-        fallback_available = os.environ.get(config.key_name)
-        if expert_key_missing and fallback_available:
-            os.environ[expert_key] = os.environ[config.key_name]
-            expert_key_missing = False
-        # Only add to missing list if still missing after fallback attempt
-        if expert_key_missing:
-            expert_missing.append(f'{expert_key} environment variable is not set')
-        # Special case for openai-compatible expert needing base URL
-        if expert_provider == "openai-compatible":
-            expert_base = 'EXPERT_OPENAI_API_BASE'
-            base_missing = not os.environ.get(expert_base)
-            base_fallback = os.environ.get('OPENAI_API_BASE')
-            if base_missing and base_fallback:
-                os.environ[expert_base] = os.environ['OPENAI_API_BASE']
-                base_missing = False
-            if base_missing:
-                expert_missing.append(f'{expert_base} environment variable is not set')
-
-    # If main keys missing, we must exit immediately
-    if missing:
-        print_error("Missing required dependencies:")
-        for item in missing:
-            print_error(f"- {item}")
-        sys.exit(1)
-
-    # If expert keys missing, we disable expert tools instead of exiting
-    expert_enabled = True
-    if expert_missing:
-        expert_enabled = False
-
-    # Check web research dependencies
-    web_research_missing = []
     web_research_enabled = False
+    web_research_missing = []
 
-    if not os.environ.get('TAVILY_API_KEY'):
+    # Validate web research dependencies
+    tavily_key = os.environ.get('TAVILY_API_KEY')
+    if not tavily_key:
         web_research_missing.append('TAVILY_API_KEY environment variable is not set')
     else:
         web_research_enabled = True
 
     return expert_enabled, expert_missing, web_research_enabled, web_research_missing
+def validate_environment(args: Any) -> tuple[bool, list[str], bool, list[str]]:
+    """Validate environment variables for providers and web research tools.
+
+    Args:
+        args: Arguments containing provider and expert provider settings
+
+    Returns:
+        Tuple containing:
+            - expert_enabled: Whether expert mode is enabled
+            - expert_missing: List of missing expert dependencies
+            - web_research_enabled: Whether web research is enabled
+            - web_research_missing: List of missing web research dependencies
+    """
+    # For research-only mode, use separate validation
+    if hasattr(args, 'research_only') and args.research_only:
+        # Only validate provider and model when testing provider validation
+        if hasattr(args, 'model') and args.model is None:
+            validate_research_only_provider(args)
+        return validate_research_only(args)
+
+    # Initialize results
+    expert_enabled = False
+    expert_missing = []
+    web_research_enabled = False
+    web_research_missing = []
+
+    # Get provider from args
+    provider = args.provider if args and hasattr(args, 'provider') else None
+    if not provider:
+        sys.exit("No provider specified")
+
+    # Validate main provider
+    strategy = ProviderFactory.create(provider, args)
+    if not strategy:
+        sys.exit(f"Unknown provider: {provider}")
+    result = strategy.validate(args)
+    if not result.valid:
+        print_missing_dependencies(result.missing_vars)
+
+    # Handle expert provider if enabled
+    if args.expert_provider:
+        # Copy base variables to expert if not set
+        copy_base_to_expert_vars(provider, args.expert_provider)
+
+        # Validate expert provider
+        expert_strategy = ProviderFactory.create(args.expert_provider, args)
+        if not expert_strategy:
+            sys.exit(f"Unknown expert provider: {args.expert_provider}")
+        expert_result = expert_strategy.validate(args)
+        expert_missing = expert_result.missing_vars
+        expert_enabled = len(expert_missing) == 0
+
+        # If expert validation failed, try to copy base variables again and revalidate
+        if not expert_enabled:
+            copy_base_to_expert_vars(provider, args.expert_provider)
+            expert_result = expert_strategy.validate(args)
+            expert_missing = expert_result.missing_vars
+            expert_enabled = len(expert_missing) == 0
+
+    # Validate web research dependencies
+    web_result = validate_web_research()
+    web_research_enabled = web_result.valid
+    web_research_missing = web_result.missing_vars
+
+    return expert_enabled, expert_missing, web_research_enabled, web_research_missing
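Editor's note: the expert fallback above is driven by copy_base_to_expert_vars, which only fills EXPERT_* variables that are still unset. A minimal sketch of the effect (key values illustrative):

    import os
    from ra_aid.env import copy_base_to_expert_vars

    os.environ["OPENAI_API_KEY"] = "base-key"          # illustrative value
    os.environ.pop("EXPERT_OPENAI_API_KEY", None)

    # An unset expert var falls back to the base var
    copy_base_to_expert_vars("openai", "openai")
    assert os.environ["EXPERT_OPENAI_API_KEY"] == "base-key"

    # An explicitly set expert var is never overwritten
    os.environ["EXPERT_OPENAI_API_KEY"] = "expert-key"
    copy_base_to_expert_vars("openai", "openai")
    assert os.environ["EXPERT_OPENAI_API_KEY"] == "expert-key"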

ra_aid/provider_strategy.py (new file, +237)

@@ -0,0 +1,237 @@
"""Provider validation strategies."""

from abc import ABC, abstractmethod
import os
import re
from dataclasses import dataclass
from typing import Optional, List, Any


@dataclass
class ValidationResult:
    """Result of validation."""
    valid: bool
    missing_vars: List[str]


class ProviderStrategy(ABC):
    """Abstract base class for provider validation strategies."""

    @abstractmethod
    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Validate provider environment variables."""
        pass
class OpenAIStrategy(ProviderStrategy):
    """OpenAI provider validation strategy."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Validate OpenAI environment variables."""
        missing = []

        # Check if we're validating expert config
        if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openai':
            key = os.environ.get('EXPERT_OPENAI_API_KEY')
            if not key or key == '':
                # Try to copy from base if not set
                base_key = os.environ.get('OPENAI_API_KEY')
                if base_key:
                    os.environ['EXPERT_OPENAI_API_KEY'] = base_key
                    key = base_key
            if not key:
                missing.append('EXPERT_OPENAI_API_KEY environment variable is not set')

            # Check expert model only for research-only mode
            if hasattr(args, 'research_only') and args.research_only:
                model = args.expert_model if hasattr(args, 'expert_model') else None
                if not model:
                    model = os.environ.get('EXPERT_OPENAI_MODEL')
                if not model:
                    model = os.environ.get('OPENAI_MODEL')
                if not model:
                    missing.append('Model is required for OpenAI provider in research-only mode')
        else:
            key = os.environ.get('OPENAI_API_KEY')
            if not key:
                missing.append('OPENAI_API_KEY environment variable is not set')

            # Check model only for research-only mode
            if hasattr(args, 'research_only') and args.research_only:
                model = args.model if hasattr(args, 'model') else None
                if not model:
                    model = os.environ.get('OPENAI_MODEL')
                if not model:
                    missing.append('Model is required for OpenAI provider in research-only mode')

        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class OpenAICompatibleStrategy(ProviderStrategy):
    """OpenAI-compatible provider validation strategy."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Validate OpenAI-compatible environment variables."""
        missing = []

        # Check if we're validating expert config
        if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openai-compatible':
            key = os.environ.get('EXPERT_OPENAI_API_KEY')
            base = os.environ.get('EXPERT_OPENAI_API_BASE')

            # Try to copy from base if not set
            if not key or key == '':
                base_key = os.environ.get('OPENAI_API_KEY')
                if base_key:
                    os.environ['EXPERT_OPENAI_API_KEY'] = base_key
                    key = base_key
            if not base or base == '':
                base_base = os.environ.get('OPENAI_API_BASE')
                if base_base:
                    os.environ['EXPERT_OPENAI_API_BASE'] = base_base
                    base = base_base

            if not key:
                missing.append('EXPERT_OPENAI_API_KEY environment variable is not set')
            if not base:
                missing.append('EXPERT_OPENAI_API_BASE environment variable is not set')

            # Check expert model only for research-only mode
            if hasattr(args, 'research_only') and args.research_only:
                model = args.expert_model if hasattr(args, 'expert_model') else None
                if not model:
                    model = os.environ.get('EXPERT_OPENAI_MODEL')
                if not model:
                    model = os.environ.get('OPENAI_MODEL')
                if not model:
                    missing.append('Model is required for OpenAI-compatible provider in research-only mode')
        else:
            key = os.environ.get('OPENAI_API_KEY')
            base = os.environ.get('OPENAI_API_BASE')
            if not key:
                missing.append('OPENAI_API_KEY environment variable is not set')
            if not base:
                missing.append('OPENAI_API_BASE environment variable is not set')

            # Check model only for research-only mode
            if hasattr(args, 'research_only') and args.research_only:
                model = args.model if hasattr(args, 'model') else None
                if not model:
                    model = os.environ.get('OPENAI_MODEL')
                if not model:
                    missing.append('Model is required for OpenAI-compatible provider in research-only mode')

        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class AnthropicStrategy(ProviderStrategy):
    """Anthropic provider validation strategy."""

    VALID_MODELS = [
        "claude-"
    ]

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Validate Anthropic environment variables and model."""
        missing = []

        # Check if we're validating expert config
        is_expert = args and hasattr(args, 'expert_provider') and args.expert_provider == 'anthropic'

        # Check API key
        if is_expert:
            key = os.environ.get('EXPERT_ANTHROPIC_API_KEY')
            if not key or key == '':
                # Try to copy from base if not set
                base_key = os.environ.get('ANTHROPIC_API_KEY')
                if base_key:
                    os.environ['EXPERT_ANTHROPIC_API_KEY'] = base_key
                    key = base_key
            if not key:
                missing.append('EXPERT_ANTHROPIC_API_KEY environment variable is not set')
        else:
            key = os.environ.get('ANTHROPIC_API_KEY')
            if not key:
                missing.append('ANTHROPIC_API_KEY environment variable is not set')

        # Check model
        model_matched = False
        model_to_check = None

        # First check command line argument
        if is_expert:
            if hasattr(args, 'expert_model') and args.expert_model:
                model_to_check = args.expert_model
            else:
                # If no expert model, check environment variable
                model_to_check = os.environ.get('EXPERT_ANTHROPIC_MODEL')
                if not model_to_check or model_to_check == '':
                    # Try to copy from base if not set
                    base_model = os.environ.get('ANTHROPIC_MODEL')
                    if base_model:
                        os.environ['EXPERT_ANTHROPIC_MODEL'] = base_model
                        model_to_check = base_model
        else:
            if hasattr(args, 'model') and args.model:
                model_to_check = args.model
            else:
                model_to_check = os.environ.get('ANTHROPIC_MODEL')
                if not model_to_check:
                    missing.append('ANTHROPIC_MODEL environment variable is not set')
                    return ValidationResult(valid=len(missing) == 0, missing_vars=missing)

        # Validate model format
        for pattern in self.VALID_MODELS:
            if re.match(pattern, model_to_check):
                model_matched = True
                break
        if not model_matched:
            missing.append(f'Invalid Anthropic model: {model_to_check}. Must match one of these patterns: {", ".join(self.VALID_MODELS)}')

        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
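Editor's note: the model check above relies on re.match, which anchors at the start of the string, so the single "claude-" entry in VALID_MODELS acts as a prefix filter. A minimal standalone sketch:

    import re

    VALID_MODELS = ["claude-"]

    def matches_anthropic_model(name: str) -> bool:
        # re.match anchors at the beginning of the string,
        # so "claude-" accepts any name starting with that prefix.
        return any(re.match(pattern, name) for pattern in VALID_MODELS)

    assert matches_anthropic_model("claude-3-5-sonnet-20241022")
    assert not matches_anthropic_model("gpt-4")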
class OpenRouterStrategy(ProviderStrategy):
    """OpenRouter provider validation strategy."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Validate OpenRouter environment variables."""
        missing = []

        # Check if we're validating expert config
        if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openrouter':
            key = os.environ.get('EXPERT_OPENROUTER_API_KEY')
            if not key or key == '':
                # Try to copy from base if not set
                base_key = os.environ.get('OPENROUTER_API_KEY')
                if base_key:
                    os.environ['EXPERT_OPENROUTER_API_KEY'] = base_key
                    key = base_key
            if not key:
                missing.append('EXPERT_OPENROUTER_API_KEY environment variable is not set')
        else:
            key = os.environ.get('OPENROUTER_API_KEY')
            if not key:
                missing.append('OPENROUTER_API_KEY environment variable is not set')

        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class ProviderFactory:
    """Factory for creating provider validation strategies."""

    @staticmethod
    def create(provider: str, args: Optional[Any] = None) -> Optional[ProviderStrategy]:
        """Create a provider validation strategy.

        Args:
            provider: Provider name
            args: Optional command line arguments

        Returns:
            Provider validation strategy or None if provider not found
        """
        strategies = {
            'openai': OpenAIStrategy(),
            'openai-compatible': OpenAICompatibleStrategy(),
            'anthropic': AnthropicStrategy(),
            'openrouter': OpenRouterStrategy()
        }
        strategy = strategies.get(provider)
        return strategy

ra_aid/tests/__init__.py (new file, +1)

@@ -0,0 +1 @@
"""Test package for RA.Aid."""

ra_aid/tests/test_env.py (new file, +49)

@@ -0,0 +1,49 @@
"""Unit tests for environment validation."""

import pytest
from dataclasses import dataclass
from typing import List

from ra_aid.env import validate_environment


@dataclass
class MockArgs:
    """Mock arguments for testing."""
    research_only: bool
    provider: str
    expert_provider: str = None


TEST_CASES = [
    pytest.param(
        "research_only_no_model",
        MockArgs(research_only=True, provider="openai"),
        (False, [], False, ["TAVILY_API_KEY environment variable is not set"]),
        {},
        id="research_only_no_model"
    ),
    pytest.param(
        "research_only_with_model",
        MockArgs(research_only=True, provider="openai"),
        (False, [], True, []),
        {"TAVILY_API_KEY": "test_key"},
        id="research_only_with_model"
    )
]


@pytest.mark.parametrize("test_name,args,expected,env_vars", TEST_CASES)
def test_validate_environment_research_only(
    test_name: str,
    args: MockArgs,
    expected: tuple,
    env_vars: dict,
    monkeypatch
):
    """Test validate_environment with research_only flag."""
    # Clear any existing environment variables
    monkeypatch.delenv("TAVILY_API_KEY", raising=False)

    # Set test environment variables
    for key, value in env_vars.items():
        monkeypatch.setenv(key, value)

    result = validate_environment(args)
    assert result == expected, f"Failed test case: {test_name}"


@@ -13,11 +13,11 @@ from ra_aid.tools.agent import request_research, request_implementation, request
 # Read-only tools that don't modify system state
 def get_read_only_tools(human_interaction: bool = False, web_research_enabled: bool = False) -> list:
     """Get the list of read-only tools, optionally including human interaction tools.
 
     Args:
         human_interaction: Whether to include human interaction tools
         web_research_enabled: Whether to include web research tools
 
     Returns:
         List of tool functions
     """
@ -37,10 +37,10 @@ def get_read_only_tools(human_interaction: bool = False, web_research_enabled: b
if web_research_enabled: if web_research_enabled:
tools.append(request_web_research) tools.append(request_web_research)
if human_interaction: if human_interaction:
tools.append(ask_human) tools.append(ask_human)
return tools return tools
# Define constant tool groups # Define constant tool groups
@@ -58,7 +58,7 @@ RESEARCH_TOOLS = [
 def get_research_tools(research_only: bool = False, expert_enabled: bool = True, human_interaction: bool = False, web_research_enabled: bool = False) -> list:
     """Get the list of research tools based on mode and whether expert is enabled.
 
     Args:
         research_only: Whether to exclude modification tools
         expert_enabled: Whether to include expert tools
@@ -67,33 +67,33 @@ def get_research_tools(research_only: bool = False, expert_enabled: bool = True,
     """
     # Start with read-only tools
     tools = get_read_only_tools(human_interaction, web_research_enabled).copy()
     tools.extend(RESEARCH_TOOLS)
 
     # Add modification tools if not research_only
     if not research_only:
         tools.extend(MODIFICATION_TOOLS)
         tools.append(request_implementation)
 
     # Add expert tools if enabled
     if expert_enabled:
         tools.extend(EXPERT_TOOLS)
 
     # Add chat-specific tools
     tools.append(request_research)
 
     return tools
 
 def get_planning_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
     """Get the list of planning tools based on whether expert is enabled.
 
     Args:
         expert_enabled: Whether to include expert tools
         web_research_enabled: Whether to include web research tools
     """
     # Start with read-only tools
     tools = get_read_only_tools(web_research_enabled=web_research_enabled).copy()
 
     # Add planning-specific tools
     planning_tools = [
         emit_plan,
@@ -101,41 +101,43 @@ def get_planning_tools(expert_enabled: bool = True, web_research_enabled: bool =
         plan_implementation_completed
     ]
     tools.extend(planning_tools)
 
     # Add expert tools if enabled
     if expert_enabled:
         tools.extend(EXPERT_TOOLS)
 
     return tools
 
 def get_implementation_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
     """Get the list of implementation tools based on whether expert is enabled.
 
     Args:
         expert_enabled: Whether to include expert tools
         web_research_enabled: Whether to include web research tools
     """
     # Start with read-only tools
     tools = get_read_only_tools(web_research_enabled=web_research_enabled).copy()
 
     # Add modification tools since it's not research-only
     tools.extend(MODIFICATION_TOOLS)
     tools.extend([
         task_completed
     ])
 
     # Add expert tools if enabled
     if expert_enabled:
         tools.extend(EXPERT_TOOLS)
 
     return tools
 
 def get_web_research_tools(expert_enabled: bool = True) -> list:
     """Get the list of tools available for web research.
 
     Args:
         expert_enabled: Whether expert tools should be included
+        human_interaction: Whether to include human interaction tools
+        web_research_enabled: Whether to include web research tools
 
     Returns:
         list: List of tools configured for web research
     """
@@ -153,10 +155,10 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
 def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
     """Get the list of tools available in chat mode.
 
     Chat mode includes research and implementation capabilities but excludes
     complex planning tools. Human interaction is always enabled.
 
     Args:
         expert_enabled: Whether to include expert tools
         web_research_enabled: Whether to include web research tools


@@ -0,0 +1,113 @@
"""Tests for default provider and model configuration."""

import os
import pytest
from dataclasses import dataclass
from typing import Optional

from ra_aid.env import validate_environment
from ra_aid.__main__ import parse_arguments


@dataclass
class MockArgs:
    """Mock arguments for testing."""
    provider: str
    expert_provider: Optional[str] = None
    model: Optional[str] = None
    expert_model: Optional[str] = None
    message: Optional[str] = None
    research_only: bool = False
    chat: bool = False


@pytest.fixture
def clean_env(monkeypatch):
    """Remove all provider-related environment variables."""
    env_vars = [
        "ANTHROPIC_API_KEY",
        "OPENAI_API_KEY",
        "OPENROUTER_API_KEY",
        "OPENAI_API_BASE",
        "EXPERT_ANTHROPIC_API_KEY",
        "EXPERT_OPENAI_API_KEY",
        "EXPERT_OPENAI_API_BASE",
        "TAVILY_API_KEY",
        "ANTHROPIC_MODEL",
    ]
    for var in env_vars:
        monkeypatch.delenv(var, raising=False)
    yield


def test_default_anthropic_provider(clean_env, monkeypatch):
    """Test that Anthropic is the default provider when no environment variables are set."""
    args = parse_arguments(["-m", "test message"])
    assert args.provider == "anthropic"
    assert args.model == "claude-3-5-sonnet-20241022"
"""Unit tests for provider and model validation in research-only mode."""
import pytest
from dataclasses import dataclass
from argparse import Namespace
from ra_aid.env import validate_environment
@dataclass
class MockArgs:
"""Mock command line arguments."""
research_only: bool = False
provider: str = None
model: str = None
expert_provider: str = None
TEST_CASES = [
pytest.param(
"research_only_no_provider",
MockArgs(research_only=True),
{},
"No provider specified",
id="research_only_no_provider"
),
pytest.param(
"research_only_anthropic",
MockArgs(research_only=True, provider="anthropic"),
{},
None,
id="research_only_anthropic"
),
pytest.param(
"research_only_non_anthropic_no_model",
MockArgs(research_only=True, provider="openai"),
{},
"Model is required for non-Anthropic providers",
id="research_only_non_anthropic_no_model"
),
pytest.param(
"research_only_non_anthropic_with_model",
MockArgs(research_only=True, provider="openai", model="gpt-4"),
{},
None,
id="research_only_non_anthropic_with_model"
)
]
@pytest.mark.parametrize("test_name,args,env_vars,expected_error", TEST_CASES)
def test_research_only_provider_validation(
test_name: str,
args: MockArgs,
env_vars: dict,
expected_error: str,
monkeypatch
):
"""Test provider and model validation in research-only mode."""
# Set test environment variables
for key, value in env_vars.items():
monkeypatch.setenv(key, value)
if expected_error:
with pytest.raises(SystemExit, match=expected_error):
validate_environment(args)
else:
validate_environment(args)


@@ -18,33 +18,32 @@ def clean_env(monkeypatch):
     env_vars = [
         'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'OPENROUTER_API_KEY',
         'OPENAI_API_BASE', 'EXPERT_ANTHROPIC_API_KEY', 'EXPERT_OPENAI_API_KEY',
-        'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'TAVILY_API_KEY'
+        'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'TAVILY_API_KEY', 'ANTHROPIC_MODEL'
     ]
     for var in env_vars:
         monkeypatch.delenv(var, raising=False)
 
 def test_anthropic_validation(clean_env, monkeypatch):
-    args = MockArgs(provider="anthropic", expert_provider="openai")
+    args = MockArgs(provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307")
 
     # Should fail without API key
     with pytest.raises(SystemExit):
         validate_environment(args)
 
-    # Should pass with API key
+    # Should pass with API key and model
     monkeypatch.setenv('ANTHROPIC_API_KEY', 'test-key')
     expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
     assert not expert_enabled
-    assert 'EXPERT_OPENAI_API_KEY environment variable is not set' in expert_missing
     assert not web_research_enabled
     assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
 
 def test_openai_validation(clean_env, monkeypatch):
     args = MockArgs(provider="openai", expert_provider="openai")
 
     # Should fail without API key
     with pytest.raises(SystemExit):
         validate_environment(args)
 
     # Should pass with API key and enable expert mode with fallback
     monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
     expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
@@ -56,16 +55,16 @@ def test_openai_validation(clean_env, monkeypatch):
 def test_openai_compatible_validation(clean_env, monkeypatch):
     args = MockArgs(provider="openai-compatible", expert_provider="openai-compatible")
 
     # Should fail without API key and base URL
     with pytest.raises(SystemExit):
         validate_environment(args)
 
     # Should fail with only API key
     monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
     with pytest.raises(SystemExit):
         validate_environment(args)
 
     # Should pass with both API key and base URL
     monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
     expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
@@ -78,10 +77,10 @@ def test_openai_compatible_validation(clean_env, monkeypatch):
 def test_expert_fallback(clean_env, monkeypatch):
     args = MockArgs(provider="openai", expert_provider="openai")
 
     # Set only base API key
     monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
 
     # Should enable expert mode with fallback
     expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
     assert expert_enabled
@ -89,7 +88,7 @@ def test_expert_fallback(clean_env, monkeypatch):
assert not web_research_enabled assert not web_research_enabled
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key' assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key'
# Should use explicit expert key if available # Should use explicit expert key if available
monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key') monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key')
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
@@ -101,25 +100,25 @@ def test_cross_provider_fallback(clean_env, monkeypatch):
 def test_cross_provider_fallback(clean_env, monkeypatch):
     """Test that fallback works even when providers differ"""
-    args = MockArgs(provider="openai", expert_provider="anthropic")
+    args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
 
     # Set base API key for main provider and expert provider
     monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
     monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
+    monkeypatch.setenv('ANTHROPIC_MODEL', 'claude-3-haiku-20240307')
 
     # Should enable expert mode with fallback to ANTHROPIC base key
     expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
     assert expert_enabled
     assert not expert_missing
     assert not web_research_enabled
     assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
-    assert os.environ.get('EXPERT_ANTHROPIC_API_KEY') == 'anthropic-key'
 
     # Try with openai-compatible expert provider
     args = MockArgs(provider="anthropic", expert_provider="openai-compatible")
     monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
     monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
 
     expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
     assert expert_enabled
     assert not expert_missing
@@ -131,43 +130,51 @@ def test_cross_provider_fallback(clean_env, monkeypatch):
 def test_no_warning_on_fallback(clean_env, monkeypatch):
     """Test that no warning is issued when fallback succeeds"""
     args = MockArgs(provider="openai", expert_provider="openai")
 
     # Set only base API key
     monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
 
     # Should enable expert mode with fallback and no warnings
     expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
     assert expert_enabled
-    assert not expert_missing  # List should be empty
+    assert not expert_missing
     assert not web_research_enabled
     assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
     assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key'
 
+    # Should use explicit expert key if available
+    monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key')
+    expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
+    assert expert_enabled
+    assert not expert_missing
+    assert not web_research_enabled
+    assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
+    assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'expert-key'
 def test_different_providers_no_expert_key(clean_env, monkeypatch):
     """Test behavior when providers differ and only base keys are available"""
-    args = MockArgs(provider="anthropic", expert_provider="openai")
+    args = MockArgs(provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307")
 
     # Set only base keys
     monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
     monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
 
     # Should enable expert mode and use base OPENAI key
     expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
     assert expert_enabled
     assert not expert_missing
     assert not web_research_enabled
     assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
-    assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'openai-key'
 
 def test_mixed_provider_openai_compatible(clean_env, monkeypatch):
     """Test behavior with openai-compatible expert and different main provider"""
-    args = MockArgs(provider="anthropic", expert_provider="openai-compatible")
+    args = MockArgs(provider="anthropic", expert_provider="openai-compatible", model="claude-3-haiku-20240307")
 
     # Set all required keys and URLs
     monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
     monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
     monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
 
     # Should enable expert mode and use base openai key and URL
     expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
     assert expert_enabled


@ -207,14 +207,13 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Remove fallback monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Remove fallback
monkeypatch.delenv("TAVILY_API_KEY", raising=False) # Remove web research monkeypatch.delenv("TAVILY_API_KEY", raising=False) # Remove web research
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") # Add for provider validation monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") # Add for provider validation
monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") # Add model for provider validation
args = Args(provider="anthropic", expert_provider="openai") # Change base provider to avoid validation error args = Args(provider="anthropic", expert_provider="openai") # Change base provider to avoid validation error
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args) expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert not expert_enabled assert not expert_enabled
assert len(expert_missing) == 1 assert expert_missing
assert expert_missing[0] == "EXPERT_OPENAI_API_KEY environment variable is not set"
assert not web_enabled assert not web_enabled
assert len(web_missing) == 1 assert web_missing
assert web_missing[0] == "TAVILY_API_KEY environment variable is not set"
@pytest.fixture
def mock_anthropic():

View File

@ -0,0 +1,259 @@
"""Integration tests for provider validation and environment handling."""
import os
import pytest
from dataclasses import dataclass
from typing import Optional
from ra_aid.env import validate_environment
from ra_aid.provider_strategy import (
    ProviderFactory,
    ValidationResult,
    AnthropicStrategy,
    OpenAIStrategy,
    OpenAICompatibleStrategy,
    OpenRouterStrategy,
)

@dataclass
class MockArgs:
"""Mock arguments for testing."""
provider: str
expert_provider: Optional[str] = None
model: Optional[str] = None
expert_model: Optional[str] = None
@pytest.fixture
def clean_env():
"""Remove all provider-related environment variables."""
env_vars = [
"ANTHROPIC_API_KEY",
"OPENAI_API_KEY",
"OPENROUTER_API_KEY",
"OPENAI_API_BASE",
"EXPERT_ANTHROPIC_API_KEY",
"EXPERT_OPENAI_API_KEY",
"EXPERT_OPENAI_API_BASE",
"TAVILY_API_KEY",
"ANTHROPIC_MODEL"
]
# Store original values
original_values = {}
for var in env_vars:
original_values[var] = os.environ.get(var)
if var in os.environ:
del os.environ[var]
yield
# Restore original values
for var, value in original_values.items():
if value is not None:
os.environ[var] = value
elif var in os.environ:
del os.environ[var]
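As an aside, pytest's built-in monkeypatch fixture already does this save/delete/restore bookkeeping; a minimal alternative sketch over the same variable list (not part of this commit):

import pytest

PROVIDER_ENV_VARS = [
    "ANTHROPIC_API_KEY", "OPENAI_API_KEY", "OPENROUTER_API_KEY",
    "OPENAI_API_BASE", "EXPERT_ANTHROPIC_API_KEY", "EXPERT_OPENAI_API_KEY",
    "EXPERT_OPENAI_API_BASE", "TAVILY_API_KEY", "ANTHROPIC_MODEL",
]

@pytest.fixture
def clean_env_alt(monkeypatch):
    """Delete provider variables; monkeypatch restores them at teardown."""
    for var in PROVIDER_ENV_VARS:
        monkeypatch.delenv(var, raising=False)
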
def test_provider_validation_respects_cli_args(clean_env):
"""Test that provider validation respects CLI args over defaults."""
# Set up environment with only OpenAI credentials
os.environ["OPENAI_API_KEY"] = "test-key"
# Should succeed with OpenAI provider
args = MockArgs(provider="openai")
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert not expert_missing
# Should fail with Anthropic provider even though it's first alphabetically
args = MockArgs(provider="anthropic", model="claude-3-haiku-20240307")
with pytest.raises(SystemExit):
validate_environment(args)
def test_expert_provider_fallback(clean_env):
"""Test expert provider falls back to main provider keys."""
os.environ["OPENAI_API_KEY"] = "test-key"
args = MockArgs(provider="openai", expert_provider="openai")
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert expert_enabled
assert not expert_missing
def test_openai_compatible_base_url(clean_env):
"""Test OpenAI-compatible provider requires base URL."""
os.environ["OPENAI_API_KEY"] = "test-key"
args = MockArgs(provider="openai-compatible")
with pytest.raises(SystemExit):
validate_environment(args)
os.environ["OPENAI_API_BASE"] = "http://test"
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_missing
def test_expert_provider_separate_keys(clean_env):
"""Test expert provider can use separate keys."""
os.environ["OPENAI_API_KEY"] = "main-key"
os.environ["EXPERT_OPENAI_API_KEY"] = "expert-key"
args = MockArgs(provider="openai", expert_provider="openai")
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert expert_enabled
assert not expert_missing
def test_web_research_independent(clean_env):
"""Test web research validation is independent of provider."""
os.environ["OPENAI_API_KEY"] = "test-key"
args = MockArgs(provider="openai")
# Without Tavily key
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert not web_enabled
assert web_missing
# With Tavily key
os.environ["TAVILY_API_KEY"] = "test-key"
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert web_enabled
assert not web_missing
def test_provider_factory_unknown_provider(clean_env):
"""Test provider factory handles unknown providers."""
strategy = ProviderFactory.create("unknown")
assert strategy is None
args = MockArgs(provider="unknown")
with pytest.raises(SystemExit):
validate_environment(args)
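The factory test above implies a plain name-to-strategy mapping in which unknown names yield None. A plausible sketch, reusing the four strategy classes imported at the top of the file (the real ProviderFactory in ra_aid.provider_strategy may be organized differently):

class ProviderFactorySketch:
    # Hypothetical mapping; the provider names mirror the CLI's valid providers.
    _strategies = {
        "anthropic": AnthropicStrategy,
        "openai": OpenAIStrategy,
        "openai-compatible": OpenAICompatibleStrategy,
        "openrouter": OpenRouterStrategy,
    }

    @classmethod
    def create(cls, provider: str):
        strategy_cls = cls._strategies.get(provider)
        return strategy_cls() if strategy_cls else None
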
def test_provider_strategy_validation(clean_env):
"""Test individual provider strategies."""
# Test Anthropic strategy
strategy = AnthropicStrategy()
result = strategy.validate()
assert not result.valid
assert "ANTHROPIC_API_KEY environment variable is not set" in result.missing_vars
os.environ["ANTHROPIC_API_KEY"] = "test-key"
os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
result = strategy.validate()
assert result.valid
assert not result.missing_vars
# Test OpenAI strategy
strategy = OpenAIStrategy()
result = strategy.validate()
assert not result.valid
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
os.environ["OPENAI_API_KEY"] = "test-key"
result = strategy.validate()
assert result.valid
assert not result.missing_vars
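Taken together, these assertions sketch the strategy contract: validate() returns a ValidationResult whose missing_vars holds human-readable messages, and the Anthropic strategy additionally requires a model. A rough illustration under that contract, not the actual ra_aid.provider_strategy code:

import os
from dataclasses import dataclass, field

@dataclass
class ValidationResultSketch:
    valid: bool
    missing_vars: list = field(default_factory=list)

class AnthropicStrategySketch:
    def validate(self) -> ValidationResultSketch:
        missing = []
        if not os.environ.get("ANTHROPIC_API_KEY"):
            missing.append("ANTHROPIC_API_KEY environment variable is not set")
        # The test sets ANTHROPIC_MODEL before expecting valid=True.
        if not os.environ.get("ANTHROPIC_MODEL"):
            missing.append("ANTHROPIC_MODEL environment variable is not set")
        return ValidationResultSketch(valid=not missing, missing_vars=missing)
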
def test_missing_provider_arg():
"""Test handling of missing provider argument."""
with pytest.raises(SystemExit):
validate_environment(None)
with pytest.raises(SystemExit):
validate_environment(MockArgs(provider=None))
def test_empty_provider_arg():
"""Test handling of empty provider argument."""
with pytest.raises(SystemExit):
validate_environment(MockArgs(provider=""))
def test_incomplete_openai_compatible_config(clean_env):
"""Test OpenAI-compatible provider with incomplete configuration."""
strategy = OpenAICompatibleStrategy()
# No configuration
result = strategy.validate()
assert not result.valid
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars
# Only API key
os.environ["OPENAI_API_KEY"] = "test-key"
result = strategy.validate()
assert not result.valid
assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars
# Only base URL
os.environ.pop("OPENAI_API_KEY")
os.environ["OPENAI_API_BASE"] = "http://test"
result = strategy.validate()
assert not result.valid
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
def test_incomplete_expert_config(clean_env):
"""Test expert provider with incomplete configuration."""
# Set main provider but not expert
os.environ["OPENAI_API_KEY"] = "test-key"
args = MockArgs(provider="openai", expert_provider="openai-compatible")
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_enabled
assert len(expert_missing) == 1
assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]
# Set expert key but not base URL
os.environ["EXPERT_OPENAI_API_KEY"] = "test-key"
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_enabled
assert len(expert_missing) == 1
assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]
def test_empty_environment_variables(clean_env):
"""Test handling of empty environment variables."""
# Empty API key
os.environ["OPENAI_API_KEY"] = ""
args = MockArgs(provider="openai")
with pytest.raises(SystemExit):
validate_environment(args)
# Empty base URL
os.environ["OPENAI_API_KEY"] = "test-key"
os.environ["OPENAI_API_BASE"] = ""
args = MockArgs(provider="openai-compatible")
with pytest.raises(SystemExit):
validate_environment(args)
def test_openrouter_validation(clean_env):
"""Test OpenRouter provider validation."""
strategy = OpenRouterStrategy()
# No API key
result = strategy.validate()
assert not result.valid
assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars
# Empty API key
os.environ["OPENROUTER_API_KEY"] = ""
result = strategy.validate()
assert not result.valid
assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars
# Valid API key
os.environ["OPENROUTER_API_KEY"] = "test-key"
result = strategy.validate()
assert result.valid
assert not result.missing_vars
def test_multiple_expert_providers(clean_env):
"""Test validation with multiple expert providers."""
os.environ["OPENAI_API_KEY"] = "test-key"
os.environ["ANTHROPIC_API_KEY"] = "test-key"
os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
# First expert provider valid, second invalid
args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert expert_enabled
assert not expert_missing
# Switch to invalid provider
args = MockArgs(provider="openai", expert_provider="openai-compatible")
expert_enabled, expert_missing, _, _ = validate_environment(args)
assert not expert_enabled
assert expert_missing
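
Across all of these tests, validate_environment consistently returns a 4-tuple: (expert_enabled, expert_missing, web_research_enabled, web_research_missing). A hedged sketch of how a caller might consume it, assuming the same validate_environment and an args object as above (the warning text is illustrative, not from the codebase):

expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
if not expert_enabled:
    for msg in expert_missing:
        print(f"Expert tools disabled: {msg}")
if not web_enabled:
    for msg in web_missing:
        print(f"Web research disabled: {msg}")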