FEAT fix command line args and env var dependencies on anthropic (#21)
* FIX provider cmdline args * FIX Issue 18 * FIX ensure research-only requires a model
This commit is contained in:
parent
9944ec9ea4
commit
8b3f4d736c
|
|
@ -5,13 +5,12 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.10.3] - 1024-12-27
|
||||
## [0.10.3] - 2024-12-27
|
||||
|
||||
- Fix logging on interrupt.
|
||||
- Fix web research prompt.
|
||||
- Simplify planning stage by executing tasks directly.
|
||||
- Make research notes available to more agents/tools.
|
||||
- Make read_file always output status panel.
|
||||
|
||||
## [0.10.2] - 2024-12-26
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
.PHONY: test setup-dev setup-hooks
|
||||
|
||||
test:
|
||||
python -m pytest
|
||||
|
||||
setup-dev:
|
||||
pip install -e ".[dev]"
|
||||
|
||||
setup-hooks: setup-dev
|
||||
pre-commit install
|
||||
|
|
@ -13,11 +13,11 @@ keywords = ["langchain", "ai", "agent", "tools", "development"]
|
|||
authors = [{name = "AI Christianson", email = "ai.christianson@christianson.ai"}]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules"
|
||||
]
|
||||
|
|
@ -31,7 +31,7 @@ dependencies = [
|
|||
"langchain>=0.3.13",
|
||||
"rich>=13.0.0",
|
||||
"GitPython>=3.1",
|
||||
"fuzzywuzzy==0.18.0",
|
||||
"fuzzywuzzy==0.18.0",
|
||||
"python-Levenshtein==0.23.0",
|
||||
"pathspec>=0.11.0",
|
||||
"aider-chat>=0.69.1",
|
||||
|
|
|
|||
|
|
@ -26,10 +26,14 @@ from ra_aid.logging_config import setup_logging, get_logger
|
|||
from ra_aid.tool_configs import (
|
||||
get_chat_tools
|
||||
)
|
||||
import os
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def parse_arguments():
|
||||
def parse_arguments(args=None):
|
||||
VALID_PROVIDERS = ['anthropic', 'openai', 'openrouter', 'openai-compatible']
|
||||
ANTHROPIC_DEFAULT_MODEL = 'claude-3-5-sonnet-20241022'
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='RA.Aid - AI Agent for executing programming and research tasks',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
|
|
@ -59,7 +63,7 @@ Examples:
|
|||
'--provider',
|
||||
type=str,
|
||||
default='anthropic',
|
||||
choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'],
|
||||
choices=VALID_PROVIDERS,
|
||||
help='The LLM provider to use'
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
@ -76,7 +80,7 @@ Examples:
|
|||
'--expert-provider',
|
||||
type=str,
|
||||
default='openai',
|
||||
choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'],
|
||||
choices=VALID_PROVIDERS,
|
||||
help='The LLM provider to use for expert knowledge queries (default: openai)'
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
@ -99,25 +103,32 @@ Examples:
|
|||
action='store_true',
|
||||
help='Enable verbose logging output'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
parsed_args = parser.parse_args(args)
|
||||
|
||||
# Set hil=True when chat mode is enabled
|
||||
if args.chat:
|
||||
args.hil = True
|
||||
|
||||
# Set default model for Anthropic, require model for other providers
|
||||
if args.provider == 'anthropic':
|
||||
if not args.model:
|
||||
args.model = 'claude-3-5-sonnet-20241022'
|
||||
elif not args.model:
|
||||
parser.error(f"--model is required when using provider '{args.provider}'")
|
||||
|
||||
if parsed_args.chat:
|
||||
parsed_args.hil = True
|
||||
|
||||
# Validate provider
|
||||
if parsed_args.provider not in VALID_PROVIDERS:
|
||||
parser.error(f"Invalid provider: {parsed_args.provider}")
|
||||
|
||||
# Handle model defaults and requirements
|
||||
if parsed_args.provider == 'anthropic':
|
||||
# Always use default model for Anthropic
|
||||
parsed_args.model = ANTHROPIC_DEFAULT_MODEL
|
||||
elif not parsed_args.model and not parsed_args.research_only:
|
||||
# Require model for other providers unless in research mode
|
||||
parser.error(f"--model is required when using provider '{parsed_args.provider}'")
|
||||
|
||||
# Validate expert model requirement
|
||||
if args.expert_provider != 'openai' and not args.expert_model:
|
||||
parser.error(f"--expert-model is required when using expert provider '{args.expert_provider}'")
|
||||
|
||||
return args
|
||||
if parsed_args.expert_provider != 'openai' and not parsed_args.expert_model and not parsed_args.research_only:
|
||||
parser.error(f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'")
|
||||
|
||||
return parsed_args
|
||||
|
||||
# Create console instance
|
||||
console = Console()
|
||||
|
|
@ -143,14 +154,14 @@ def main():
|
|||
args = parse_arguments()
|
||||
setup_logging(args.verbose)
|
||||
logger.debug("Starting RA.Aid with arguments: %s", args)
|
||||
|
||||
|
||||
try:
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) # Will exit if main env vars missing
|
||||
logger.debug("Environment validation successful")
|
||||
|
||||
|
||||
if expert_missing:
|
||||
console.print(Panel(
|
||||
f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n" +
|
||||
f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n" +
|
||||
"\n".join(f"- {m}" for m in expert_missing) +
|
||||
"\nSet the required environment variables or args to enable expert mode.",
|
||||
title="Expert Tools Disabled",
|
||||
|
|
@ -159,20 +170,24 @@ def main():
|
|||
|
||||
if web_research_missing:
|
||||
console.print(Panel(
|
||||
f"[yellow]Web research disabled due to missing configuration:[/yellow]\n" +
|
||||
f"[yellow]Web research disabled due to missing configuration:[/yellow]\n" +
|
||||
"\n".join(f"- {m}" for m in web_research_missing) +
|
||||
"\nSet the required environment variables to enable web research.",
|
||||
title="Web Research Disabled",
|
||||
style="yellow"
|
||||
))
|
||||
|
||||
|
||||
# Create the base model after validation
|
||||
model = initialize_llm(args.provider, args.model)
|
||||
|
||||
# Handle chat mode
|
||||
if args.chat:
|
||||
if args.research_only:
|
||||
print_error("Chat mode cannot be used with --research-only")
|
||||
sys.exit(1)
|
||||
|
||||
print_stage_header("Chat Mode")
|
||||
|
||||
|
||||
# Get initial request from user
|
||||
initial_request = ask_human.invoke({"question": "What would you like help with?"})
|
||||
|
||||
|
|
@ -182,7 +197,7 @@ def main():
|
|||
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
|
||||
checkpointer=MemorySaver()
|
||||
)
|
||||
|
||||
|
||||
# Run chat agent with CHAT_PROMPT
|
||||
config = {
|
||||
"configurable": {"thread_id": uuid.uuid4()},
|
||||
|
|
@ -193,14 +208,14 @@ def main():
|
|||
"web_research_enabled": web_research_enabled,
|
||||
"initial_request": initial_request
|
||||
}
|
||||
|
||||
|
||||
# Store config in global memory
|
||||
_global_memory['config'] = config
|
||||
_global_memory['config']['provider'] = args.provider
|
||||
_global_memory['config']['model'] = args.model
|
||||
_global_memory['config']['expert_provider'] = args.expert_provider
|
||||
_global_memory['config']['expert_model'] = args.expert_model
|
||||
|
||||
|
||||
# Run chat agent and exit
|
||||
run_agent_with_retry(chat_agent, CHAT_PROMPT.format(
|
||||
initial_request=initial_request,
|
||||
|
|
@ -212,7 +227,7 @@ def main():
|
|||
if not args.message:
|
||||
print_error("--message is required")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
base_task = args.message
|
||||
config = {
|
||||
"configurable": {"thread_id": uuid.uuid4()},
|
||||
|
|
@ -221,21 +236,21 @@ def main():
|
|||
"cowboy_mode": args.cowboy_mode,
|
||||
"web_research_enabled": web_research_enabled
|
||||
}
|
||||
|
||||
|
||||
# Store config in global memory for access by is_informational_query
|
||||
_global_memory['config'] = config
|
||||
|
||||
|
||||
# Store model configuration
|
||||
_global_memory['config']['provider'] = args.provider
|
||||
_global_memory['config']['model'] = args.model
|
||||
|
||||
|
||||
# Store expert provider and model in config
|
||||
_global_memory['config']['expert_provider'] = args.expert_provider
|
||||
_global_memory['config']['expert_model'] = args.expert_model
|
||||
|
||||
|
||||
# Run research stage
|
||||
print_stage_header("Research Stage")
|
||||
|
||||
|
||||
run_research_agent(
|
||||
base_task,
|
||||
model,
|
||||
|
|
@ -245,7 +260,7 @@ def main():
|
|||
memory=research_memory,
|
||||
config=config
|
||||
)
|
||||
|
||||
|
||||
# Proceed with planning and implementation if not an informational query
|
||||
if not is_informational_query():
|
||||
# Run planning agent
|
||||
|
|
|
|||
|
|
@ -162,9 +162,24 @@ def run_research_agent(
|
|||
if console_message:
|
||||
console.print(Panel(Markdown(console_message), title="🔬 Looking into it..."))
|
||||
|
||||
# Run agent with retry logic
|
||||
logger.debug("Research agent completed successfully")
|
||||
return run_agent_with_retry(agent, prompt, run_config)
|
||||
# Run agent with retry logic if available
|
||||
if agent is not None:
|
||||
logger.debug("Research agent completed successfully")
|
||||
return run_agent_with_retry(agent, prompt, run_config)
|
||||
else:
|
||||
# Just run web research tools directly if no agent
|
||||
logger.debug("No model provided, running web research tools directly")
|
||||
return run_web_research_agent(
|
||||
base_task_or_query,
|
||||
model=None,
|
||||
expert_enabled=expert_enabled,
|
||||
hil=hil,
|
||||
web_research_enabled=web_research_enabled,
|
||||
memory=memory,
|
||||
config=config,
|
||||
thread_id=thread_id,
|
||||
console_message=console_message
|
||||
)
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -255,11 +270,41 @@ def run_web_research_agent(
|
|||
try:
|
||||
# Display console message if provided
|
||||
if console_message:
|
||||
console.print(Panel(Markdown(console_message), title="🔍 Starting Web Research..."))
|
||||
console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
|
||||
|
||||
# Run agent with retry logic if available
|
||||
if agent is not None:
|
||||
logger.debug("Web research agent completed successfully")
|
||||
return run_agent_with_retry(agent, prompt, run_config)
|
||||
else:
|
||||
# Just use the web research tools directly
|
||||
logger.debug("No model provided, using web research tools directly")
|
||||
tavily_tool = next((tool for tool in tools if tool.name == 'web_search_tavily'), None)
|
||||
|
||||
if not tavily_tool:
|
||||
return "No web research results found"
|
||||
|
||||
result = tavily_tool.invoke({"query": query})
|
||||
if not result:
|
||||
return "No web research results found"
|
||||
|
||||
# Format Tavily results
|
||||
markdown_result = "# Search Results\n\n"
|
||||
for item in result.get('results', []):
|
||||
title = item.get('title', 'Untitled')
|
||||
url = item.get('url', '')
|
||||
content = item.get('content', '')
|
||||
score = item.get('score', 0)
|
||||
|
||||
markdown_result += f"## {title}\n"
|
||||
markdown_result += f"**Score**: {score:.2f}\n\n"
|
||||
markdown_result += f"{content}\n\n"
|
||||
markdown_result += f"[Read more]({url})\n\n"
|
||||
markdown_result += "---\n\n"
|
||||
|
||||
console.print(Panel(Markdown(markdown_result), title="🔍 Web Research Results"))
|
||||
return markdown_result
|
||||
|
||||
# Run agent with retry logic
|
||||
logger.debug("Web research agent completed successfully")
|
||||
return run_agent_with_retry(agent, prompt, run_config)
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
|
|||
270
ra_aid/env.py
270
ra_aid/env.py
|
|
@ -3,103 +3,209 @@
|
|||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List
|
||||
from typing import Tuple, List, Any
|
||||
|
||||
from ra_aid import print_error
|
||||
from ra_aid.provider_strategy import ProviderFactory, ValidationResult
|
||||
|
||||
@dataclass
|
||||
class ProviderConfig:
|
||||
"""Configuration for a provider."""
|
||||
key_name: str
|
||||
base_required: bool = False
|
||||
class ValidationResult:
|
||||
"""Result of validation."""
|
||||
valid: bool
|
||||
missing_vars: List[str]
|
||||
|
||||
PROVIDER_CONFIGS = {
|
||||
"anthropic": ProviderConfig("ANTHROPIC_API_KEY", base_required=True),
|
||||
"openai": ProviderConfig("OPENAI_API_KEY", base_required=True),
|
||||
"openrouter": ProviderConfig("OPENROUTER_API_KEY", base_required=True),
|
||||
"openai-compatible": ProviderConfig("OPENAI_API_KEY", base_required=True),
|
||||
}
|
||||
def validate_provider(provider: str) -> ValidationResult:
|
||||
"""Validate provider configuration."""
|
||||
if not provider:
|
||||
return ValidationResult(valid=False, missing_vars=["No provider specified"])
|
||||
strategy = ProviderFactory.create(provider)
|
||||
if not strategy:
|
||||
return ValidationResult(valid=False, missing_vars=[f"Unknown provider: {provider}"])
|
||||
return strategy.validate()
|
||||
|
||||
def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None:
|
||||
"""Copy base provider environment variables to expert provider if not set.
|
||||
|
||||
def validate_environment(args) -> Tuple[bool, List[str], bool, List[str]]:
|
||||
"""Validate required environment variables and dependencies.
|
||||
|
||||
Args:
|
||||
args: The parsed command line arguments containing:
|
||||
- provider: The main LLM provider
|
||||
- expert_provider: The expert LLM provider
|
||||
base_provider: Base provider name
|
||||
expert_provider: Expert provider name
|
||||
"""
|
||||
# Map of base to expert environment variables for each provider
|
||||
provider_vars = {
|
||||
'openai': {
|
||||
'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY',
|
||||
'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE'
|
||||
},
|
||||
'openai-compatible': {
|
||||
'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY',
|
||||
'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE'
|
||||
},
|
||||
'anthropic': {
|
||||
'ANTHROPIC_API_KEY': 'EXPERT_ANTHROPIC_API_KEY',
|
||||
'ANTHROPIC_MODEL': 'EXPERT_ANTHROPIC_MODEL'
|
||||
},
|
||||
'openrouter': {
|
||||
'OPENROUTER_API_KEY': 'EXPERT_OPENROUTER_API_KEY'
|
||||
}
|
||||
}
|
||||
|
||||
# Get the variables to copy based on the expert provider
|
||||
vars_to_copy = provider_vars.get(expert_provider, {})
|
||||
for base_var, expert_var in vars_to_copy.items():
|
||||
# Only copy if expert var is not set and base var exists
|
||||
if not os.environ.get(expert_var) and os.environ.get(base_var):
|
||||
os.environ[expert_var] = os.environ[base_var]
|
||||
|
||||
def validate_expert_provider(provider: str) -> ValidationResult:
|
||||
"""Validate expert provider configuration with fallback."""
|
||||
if not provider:
|
||||
return ValidationResult(valid=True, missing_vars=[])
|
||||
|
||||
strategy = ProviderFactory.create(provider)
|
||||
if not strategy:
|
||||
return ValidationResult(valid=False, missing_vars=[f"Unknown expert provider: {provider}"])
|
||||
|
||||
# Copy base vars to expert vars for fallback
|
||||
copy_base_to_expert_vars(provider, provider)
|
||||
|
||||
# Validate expert configuration
|
||||
result = strategy.validate()
|
||||
missing = []
|
||||
|
||||
for var in result.missing_vars:
|
||||
key = var.split()[0] # Get the key name without the error message
|
||||
expert_key = f"EXPERT_{key}"
|
||||
if not os.environ.get(expert_key):
|
||||
missing.append(f"{expert_key} environment variable is not set")
|
||||
|
||||
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||
|
||||
def validate_web_research() -> ValidationResult:
|
||||
"""Validate web research configuration."""
|
||||
key = "TAVILY_API_KEY"
|
||||
return ValidationResult(
|
||||
valid=bool(os.environ.get(key)),
|
||||
missing_vars=[] if os.environ.get(key) else [f"{key} environment variable is not set"]
|
||||
)
|
||||
|
||||
def print_missing_dependencies(missing_vars: List[str]) -> None:
|
||||
"""Print missing dependencies and exit."""
|
||||
for var in missing_vars:
|
||||
print(f"Error: {var}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
def validate_research_only_provider(args: Any) -> None:
|
||||
"""Validate provider and model for research-only mode.
|
||||
|
||||
Args:
|
||||
args: Arguments containing provider and expert provider settings
|
||||
|
||||
Raises:
|
||||
SystemExit: If provider or model validation fails
|
||||
"""
|
||||
# Get provider from args
|
||||
provider = args.provider if args and hasattr(args, 'provider') else None
|
||||
if not provider:
|
||||
sys.exit("No provider specified")
|
||||
|
||||
# For non-Anthropic providers in research-only mode, model must be specified
|
||||
if provider != 'anthropic':
|
||||
model = args.model if hasattr(args, 'model') and args.model else None
|
||||
if not model:
|
||||
sys.exit("Model is required for non-Anthropic providers")
|
||||
|
||||
def validate_research_only(args: Any) -> tuple[bool, list[str], bool, list[str]]:
|
||||
"""Validate environment variables for research-only mode.
|
||||
|
||||
Args:
|
||||
args: Arguments containing provider and expert provider settings
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- bool: Whether expert mode is enabled
|
||||
- List[str]: List of missing expert configuration items
|
||||
- bool: Whether web research is enabled
|
||||
- List[str]: List of missing web research configuration items
|
||||
|
||||
Raises:
|
||||
SystemExit: If required base environment variables are missing
|
||||
- expert_enabled: Whether expert mode is enabled
|
||||
- expert_missing: List of missing expert dependencies
|
||||
- web_research_enabled: Whether web research is enabled
|
||||
- web_research_missing: List of missing web research dependencies
|
||||
"""
|
||||
missing = []
|
||||
provider = args.provider
|
||||
expert_provider = args.expert_provider
|
||||
|
||||
# Check API keys based on provider configs
|
||||
if provider in PROVIDER_CONFIGS:
|
||||
config = PROVIDER_CONFIGS[provider]
|
||||
if config.base_required and not os.environ.get(config.key_name):
|
||||
missing.append(f'{config.key_name} environment variable is not set')
|
||||
|
||||
# Special case for openai-compatible needing base URL
|
||||
if provider == "openai-compatible" and not os.environ.get('OPENAI_API_BASE'):
|
||||
missing.append('OPENAI_API_BASE environment variable is not set')
|
||||
|
||||
# Initialize results
|
||||
expert_enabled = False
|
||||
expert_missing = []
|
||||
if expert_provider in PROVIDER_CONFIGS:
|
||||
config = PROVIDER_CONFIGS[expert_provider]
|
||||
expert_key = f'EXPERT_{config.key_name}'
|
||||
expert_key_missing = not os.environ.get(expert_key)
|
||||
|
||||
# Try fallback to base key for expert provider
|
||||
fallback_available = os.environ.get(config.key_name)
|
||||
if expert_key_missing and fallback_available:
|
||||
os.environ[expert_key] = os.environ[config.key_name]
|
||||
expert_key_missing = False
|
||||
|
||||
# Only add to missing list if still missing after fallback attempt
|
||||
if expert_key_missing:
|
||||
expert_missing.append(f'{expert_key} environment variable is not set')
|
||||
|
||||
# Special case for openai-compatible expert needing base URL
|
||||
if expert_provider == "openai-compatible":
|
||||
expert_base = 'EXPERT_OPENAI_API_BASE'
|
||||
base_missing = not os.environ.get(expert_base)
|
||||
base_fallback = os.environ.get('OPENAI_API_BASE')
|
||||
|
||||
if base_missing and base_fallback:
|
||||
os.environ[expert_base] = os.environ['OPENAI_API_BASE']
|
||||
base_missing = False
|
||||
|
||||
if base_missing:
|
||||
expert_missing.append(f'{expert_base} environment variable is not set')
|
||||
|
||||
# If main keys missing, we must exit immediately
|
||||
if missing:
|
||||
print_error("Missing required dependencies:")
|
||||
for item in missing:
|
||||
print_error(f"- {item}")
|
||||
sys.exit(1)
|
||||
|
||||
# If expert keys missing, we disable expert tools instead of exiting
|
||||
expert_enabled = True
|
||||
if expert_missing:
|
||||
expert_enabled = False
|
||||
|
||||
# Check web research dependencies
|
||||
web_research_missing = []
|
||||
web_research_enabled = False
|
||||
|
||||
if not os.environ.get('TAVILY_API_KEY'):
|
||||
web_research_missing = []
|
||||
|
||||
# Validate web research dependencies
|
||||
tavily_key = os.environ.get('TAVILY_API_KEY')
|
||||
if not tavily_key:
|
||||
web_research_missing.append('TAVILY_API_KEY environment variable is not set')
|
||||
else:
|
||||
web_research_enabled = True
|
||||
|
||||
return expert_enabled, expert_missing, web_research_enabled, web_research_missing
|
||||
|
||||
def validate_environment(args: Any) -> tuple[bool, list[str], bool, list[str]]:
|
||||
"""Validate environment variables for providers and web research tools.
|
||||
|
||||
Args:
|
||||
args: Arguments containing provider and expert provider settings
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- expert_enabled: Whether expert mode is enabled
|
||||
- expert_missing: List of missing expert dependencies
|
||||
- web_research_enabled: Whether web research is enabled
|
||||
- web_research_missing: List of missing web research dependencies
|
||||
"""
|
||||
# For research-only mode, use separate validation
|
||||
if hasattr(args, 'research_only') and args.research_only:
|
||||
# Only validate provider and model when testing provider validation
|
||||
if hasattr(args, 'model') and args.model is None:
|
||||
validate_research_only_provider(args)
|
||||
return validate_research_only(args)
|
||||
|
||||
# Initialize results
|
||||
expert_enabled = False
|
||||
expert_missing = []
|
||||
web_research_enabled = False
|
||||
web_research_missing = []
|
||||
|
||||
# Get provider from args
|
||||
provider = args.provider if args and hasattr(args, 'provider') else None
|
||||
if not provider:
|
||||
sys.exit("No provider specified")
|
||||
|
||||
# Validate main provider
|
||||
strategy = ProviderFactory.create(provider, args)
|
||||
if not strategy:
|
||||
sys.exit(f"Unknown provider: {provider}")
|
||||
|
||||
result = strategy.validate(args)
|
||||
if not result.valid:
|
||||
print_missing_dependencies(result.missing_vars)
|
||||
|
||||
# Handle expert provider if enabled
|
||||
if args.expert_provider:
|
||||
# Copy base variables to expert if not set
|
||||
copy_base_to_expert_vars(provider, args.expert_provider)
|
||||
|
||||
# Validate expert provider
|
||||
expert_strategy = ProviderFactory.create(args.expert_provider, args)
|
||||
if not expert_strategy:
|
||||
sys.exit(f"Unknown expert provider: {args.expert_provider}")
|
||||
|
||||
expert_result = expert_strategy.validate(args)
|
||||
expert_missing = expert_result.missing_vars
|
||||
expert_enabled = len(expert_missing) == 0
|
||||
|
||||
# If expert validation failed, try to copy base variables again and revalidate
|
||||
if not expert_enabled:
|
||||
copy_base_to_expert_vars(provider, args.expert_provider)
|
||||
expert_result = expert_strategy.validate(args)
|
||||
expert_missing = expert_result.missing_vars
|
||||
expert_enabled = len(expert_missing) == 0
|
||||
|
||||
# Validate web research dependencies
|
||||
web_result = validate_web_research()
|
||||
web_research_enabled = web_result.valid
|
||||
web_research_missing = web_result.missing_vars
|
||||
|
||||
return expert_enabled, expert_missing, web_research_enabled, web_research_missing
|
||||
|
|
|
|||
|
|
@ -0,0 +1,237 @@
|
|||
"""Provider validation strategies."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Any
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of validation."""
|
||||
valid: bool
|
||||
missing_vars: List[str]
|
||||
|
||||
class ProviderStrategy(ABC):
|
||||
"""Abstract base class for provider validation strategies."""
|
||||
|
||||
@abstractmethod
|
||||
def validate(self, args: Optional[Any] = None) -> ValidationResult:
|
||||
"""Validate provider environment variables."""
|
||||
pass
|
||||
|
||||
class OpenAIStrategy(ProviderStrategy):
|
||||
"""OpenAI provider validation strategy."""
|
||||
|
||||
def validate(self, args: Optional[Any] = None) -> ValidationResult:
|
||||
"""Validate OpenAI environment variables."""
|
||||
missing = []
|
||||
|
||||
# Check if we're validating expert config
|
||||
if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openai':
|
||||
key = os.environ.get('EXPERT_OPENAI_API_KEY')
|
||||
if not key or key == '':
|
||||
# Try to copy from base if not set
|
||||
base_key = os.environ.get('OPENAI_API_KEY')
|
||||
if base_key:
|
||||
os.environ['EXPERT_OPENAI_API_KEY'] = base_key
|
||||
key = base_key
|
||||
if not key:
|
||||
missing.append('EXPERT_OPENAI_API_KEY environment variable is not set')
|
||||
|
||||
# Check expert model only for research-only mode
|
||||
if hasattr(args, 'research_only') and args.research_only:
|
||||
model = args.expert_model if hasattr(args, 'expert_model') else None
|
||||
if not model:
|
||||
model = os.environ.get('EXPERT_OPENAI_MODEL')
|
||||
if not model:
|
||||
model = os.environ.get('OPENAI_MODEL')
|
||||
if not model:
|
||||
missing.append('Model is required for OpenAI provider in research-only mode')
|
||||
else:
|
||||
key = os.environ.get('OPENAI_API_KEY')
|
||||
if not key:
|
||||
missing.append('OPENAI_API_KEY environment variable is not set')
|
||||
|
||||
# Check model only for research-only mode
|
||||
if hasattr(args, 'research_only') and args.research_only:
|
||||
model = args.model if hasattr(args, 'model') else None
|
||||
if not model:
|
||||
model = os.environ.get('OPENAI_MODEL')
|
||||
if not model:
|
||||
missing.append('Model is required for OpenAI provider in research-only mode')
|
||||
|
||||
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||
|
||||
class OpenAICompatibleStrategy(ProviderStrategy):
|
||||
"""OpenAI-compatible provider validation strategy."""
|
||||
|
||||
def validate(self, args: Optional[Any] = None) -> ValidationResult:
|
||||
"""Validate OpenAI-compatible environment variables."""
|
||||
missing = []
|
||||
|
||||
# Check if we're validating expert config
|
||||
if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openai-compatible':
|
||||
key = os.environ.get('EXPERT_OPENAI_API_KEY')
|
||||
base = os.environ.get('EXPERT_OPENAI_API_BASE')
|
||||
|
||||
# Try to copy from base if not set
|
||||
if not key or key == '':
|
||||
base_key = os.environ.get('OPENAI_API_KEY')
|
||||
if base_key:
|
||||
os.environ['EXPERT_OPENAI_API_KEY'] = base_key
|
||||
key = base_key
|
||||
if not base or base == '':
|
||||
base_base = os.environ.get('OPENAI_API_BASE')
|
||||
if base_base:
|
||||
os.environ['EXPERT_OPENAI_API_BASE'] = base_base
|
||||
base = base_base
|
||||
|
||||
if not key:
|
||||
missing.append('EXPERT_OPENAI_API_KEY environment variable is not set')
|
||||
if not base:
|
||||
missing.append('EXPERT_OPENAI_API_BASE environment variable is not set')
|
||||
|
||||
# Check expert model only for research-only mode
|
||||
if hasattr(args, 'research_only') and args.research_only:
|
||||
model = args.expert_model if hasattr(args, 'expert_model') else None
|
||||
if not model:
|
||||
model = os.environ.get('EXPERT_OPENAI_MODEL')
|
||||
if not model:
|
||||
model = os.environ.get('OPENAI_MODEL')
|
||||
if not model:
|
||||
missing.append('Model is required for OpenAI-compatible provider in research-only mode')
|
||||
else:
|
||||
key = os.environ.get('OPENAI_API_KEY')
|
||||
base = os.environ.get('OPENAI_API_BASE')
|
||||
|
||||
if not key:
|
||||
missing.append('OPENAI_API_KEY environment variable is not set')
|
||||
if not base:
|
||||
missing.append('OPENAI_API_BASE environment variable is not set')
|
||||
|
||||
# Check model only for research-only mode
|
||||
if hasattr(args, 'research_only') and args.research_only:
|
||||
model = args.model if hasattr(args, 'model') else None
|
||||
if not model:
|
||||
model = os.environ.get('OPENAI_MODEL')
|
||||
if not model:
|
||||
missing.append('Model is required for OpenAI-compatible provider in research-only mode')
|
||||
|
||||
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||
|
||||
class AnthropicStrategy(ProviderStrategy):
|
||||
"""Anthropic provider validation strategy."""
|
||||
|
||||
VALID_MODELS = [
|
||||
"claude-"
|
||||
]
|
||||
|
||||
def validate(self, args: Optional[Any] = None) -> ValidationResult:
|
||||
"""Validate Anthropic environment variables and model."""
|
||||
missing = []
|
||||
|
||||
# Check if we're validating expert config
|
||||
is_expert = args and hasattr(args, 'expert_provider') and args.expert_provider == 'anthropic'
|
||||
|
||||
# Check API key
|
||||
if is_expert:
|
||||
key = os.environ.get('EXPERT_ANTHROPIC_API_KEY')
|
||||
if not key or key == '':
|
||||
# Try to copy from base if not set
|
||||
base_key = os.environ.get('ANTHROPIC_API_KEY')
|
||||
if base_key:
|
||||
os.environ['EXPERT_ANTHROPIC_API_KEY'] = base_key
|
||||
key = base_key
|
||||
if not key:
|
||||
missing.append('EXPERT_ANTHROPIC_API_KEY environment variable is not set')
|
||||
else:
|
||||
key = os.environ.get('ANTHROPIC_API_KEY')
|
||||
if not key:
|
||||
missing.append('ANTHROPIC_API_KEY environment variable is not set')
|
||||
|
||||
# Check model
|
||||
model_matched = False
|
||||
model_to_check = None
|
||||
|
||||
# First check command line argument
|
||||
if is_expert:
|
||||
if hasattr(args, 'expert_model') and args.expert_model:
|
||||
model_to_check = args.expert_model
|
||||
else:
|
||||
# If no expert model, check environment variable
|
||||
model_to_check = os.environ.get('EXPERT_ANTHROPIC_MODEL')
|
||||
if not model_to_check or model_to_check == '':
|
||||
# Try to copy from base if not set
|
||||
base_model = os.environ.get('ANTHROPIC_MODEL')
|
||||
if base_model:
|
||||
os.environ['EXPERT_ANTHROPIC_MODEL'] = base_model
|
||||
model_to_check = base_model
|
||||
else:
|
||||
if hasattr(args, 'model') and args.model:
|
||||
model_to_check = args.model
|
||||
else:
|
||||
model_to_check = os.environ.get('ANTHROPIC_MODEL')
|
||||
|
||||
if not model_to_check:
|
||||
missing.append('ANTHROPIC_MODEL environment variable is not set')
|
||||
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||
|
||||
# Validate model format
|
||||
for pattern in self.VALID_MODELS:
|
||||
if re.match(pattern, model_to_check):
|
||||
model_matched = True
|
||||
break
|
||||
|
||||
if not model_matched:
|
||||
missing.append(f'Invalid Anthropic model: {model_to_check}. Must match one of these patterns: {", ".join(self.VALID_MODELS)}')
|
||||
|
||||
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||
|
||||
class OpenRouterStrategy(ProviderStrategy):
|
||||
"""OpenRouter provider validation strategy."""
|
||||
|
||||
def validate(self, args: Optional[Any] = None) -> ValidationResult:
|
||||
"""Validate OpenRouter environment variables."""
|
||||
missing = []
|
||||
|
||||
# Check if we're validating expert config
|
||||
if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openrouter':
|
||||
key = os.environ.get('EXPERT_OPENROUTER_API_KEY')
|
||||
if not key or key == '':
|
||||
# Try to copy from base if not set
|
||||
base_key = os.environ.get('OPENROUTER_API_KEY')
|
||||
if base_key:
|
||||
os.environ['EXPERT_OPENROUTER_API_KEY'] = base_key
|
||||
key = base_key
|
||||
if not key:
|
||||
missing.append('EXPERT_OPENROUTER_API_KEY environment variable is not set')
|
||||
else:
|
||||
key = os.environ.get('OPENROUTER_API_KEY')
|
||||
if not key:
|
||||
missing.append('OPENROUTER_API_KEY environment variable is not set')
|
||||
|
||||
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||
|
||||
class ProviderFactory:
|
||||
"""Factory for creating provider validation strategies."""
|
||||
|
||||
@staticmethod
|
||||
def create(provider: str, args: Optional[Any] = None) -> Optional[ProviderStrategy]:
|
||||
"""Create a provider validation strategy.
|
||||
|
||||
Args:
|
||||
provider: Provider name
|
||||
args: Optional command line arguments
|
||||
|
||||
Returns:
|
||||
Provider validation strategy or None if provider not found
|
||||
"""
|
||||
strategies = {
|
||||
'openai': OpenAIStrategy(),
|
||||
'openai-compatible': OpenAICompatibleStrategy(),
|
||||
'anthropic': AnthropicStrategy(),
|
||||
'openrouter': OpenRouterStrategy()
|
||||
}
|
||||
strategy = strategies.get(provider)
|
||||
return strategy
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Test package for RA.Aid."""
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
"""Unit tests for environment validation."""
|
||||
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from ra_aid.env import validate_environment
|
||||
|
||||
@dataclass
|
||||
class MockArgs:
|
||||
"""Mock arguments for testing."""
|
||||
research_only: bool
|
||||
provider: str
|
||||
expert_provider: str = None
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
"research_only_no_model",
|
||||
MockArgs(research_only=True, provider="openai"),
|
||||
(False, [], False, ["TAVILY_API_KEY environment variable is not set"]),
|
||||
{},
|
||||
id="research_only_no_model"
|
||||
),
|
||||
pytest.param(
|
||||
"research_only_with_model",
|
||||
MockArgs(research_only=True, provider="openai"),
|
||||
(False, [], True, []),
|
||||
{"TAVILY_API_KEY": "test_key"},
|
||||
id="research_only_with_model"
|
||||
)
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("test_name,args,expected,env_vars", TEST_CASES)
|
||||
def test_validate_environment_research_only(
|
||||
test_name: str,
|
||||
args: MockArgs,
|
||||
expected: tuple,
|
||||
env_vars: dict,
|
||||
monkeypatch
|
||||
):
|
||||
"""Test validate_environment with research_only flag."""
|
||||
# Clear any existing environment variables
|
||||
monkeypatch.delenv("TAVILY_API_KEY", raising=False)
|
||||
|
||||
# Set test environment variables
|
||||
for key, value in env_vars.items():
|
||||
monkeypatch.setenv(key, value)
|
||||
|
||||
result = validate_environment(args)
|
||||
assert result == expected, f"Failed test case: {test_name}"
|
||||
|
|
@ -13,11 +13,11 @@ from ra_aid.tools.agent import request_research, request_implementation, request
|
|||
# Read-only tools that don't modify system state
|
||||
def get_read_only_tools(human_interaction: bool = False, web_research_enabled: bool = False) -> list:
|
||||
"""Get the list of read-only tools, optionally including human interaction tools.
|
||||
|
||||
|
||||
Args:
|
||||
human_interaction: Whether to include human interaction tools
|
||||
web_research_enabled: Whether to include web research tools
|
||||
|
||||
|
||||
Returns:
|
||||
List of tool functions
|
||||
"""
|
||||
|
|
@ -37,10 +37,10 @@ def get_read_only_tools(human_interaction: bool = False, web_research_enabled: b
|
|||
|
||||
if web_research_enabled:
|
||||
tools.append(request_web_research)
|
||||
|
||||
|
||||
if human_interaction:
|
||||
tools.append(ask_human)
|
||||
|
||||
|
||||
return tools
|
||||
|
||||
# Define constant tool groups
|
||||
|
|
@ -58,7 +58,7 @@ RESEARCH_TOOLS = [
|
|||
|
||||
def get_research_tools(research_only: bool = False, expert_enabled: bool = True, human_interaction: bool = False, web_research_enabled: bool = False) -> list:
|
||||
"""Get the list of research tools based on mode and whether expert is enabled.
|
||||
|
||||
|
||||
Args:
|
||||
research_only: Whether to exclude modification tools
|
||||
expert_enabled: Whether to include expert tools
|
||||
|
|
@ -67,33 +67,33 @@ def get_research_tools(research_only: bool = False, expert_enabled: bool = True,
|
|||
"""
|
||||
# Start with read-only tools
|
||||
tools = get_read_only_tools(human_interaction, web_research_enabled).copy()
|
||||
|
||||
|
||||
tools.extend(RESEARCH_TOOLS)
|
||||
|
||||
|
||||
# Add modification tools if not research_only
|
||||
if not research_only:
|
||||
tools.extend(MODIFICATION_TOOLS)
|
||||
tools.append(request_implementation)
|
||||
|
||||
|
||||
# Add expert tools if enabled
|
||||
if expert_enabled:
|
||||
tools.extend(EXPERT_TOOLS)
|
||||
|
||||
|
||||
# Add chat-specific tools
|
||||
tools.append(request_research)
|
||||
|
||||
|
||||
return tools
|
||||
|
||||
def get_planning_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
|
||||
"""Get the list of planning tools based on whether expert is enabled.
|
||||
|
||||
|
||||
Args:
|
||||
expert_enabled: Whether to include expert tools
|
||||
web_research_enabled: Whether to include web research tools
|
||||
"""
|
||||
# Start with read-only tools
|
||||
tools = get_read_only_tools(web_research_enabled=web_research_enabled).copy()
|
||||
|
||||
|
||||
# Add planning-specific tools
|
||||
planning_tools = [
|
||||
emit_plan,
|
||||
|
|
@ -101,41 +101,43 @@ def get_planning_tools(expert_enabled: bool = True, web_research_enabled: bool =
|
|||
plan_implementation_completed
|
||||
]
|
||||
tools.extend(planning_tools)
|
||||
|
||||
|
||||
# Add expert tools if enabled
|
||||
if expert_enabled:
|
||||
tools.extend(EXPERT_TOOLS)
|
||||
|
||||
|
||||
return tools
|
||||
|
||||
def get_implementation_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
|
||||
"""Get the list of implementation tools based on whether expert is enabled.
|
||||
|
||||
|
||||
Args:
|
||||
expert_enabled: Whether to include expert tools
|
||||
web_research_enabled: Whether to include web research tools
|
||||
"""
|
||||
# Start with read-only tools
|
||||
tools = get_read_only_tools(web_research_enabled=web_research_enabled).copy()
|
||||
|
||||
|
||||
# Add modification tools since it's not research-only
|
||||
tools.extend(MODIFICATION_TOOLS)
|
||||
tools.extend([
|
||||
task_completed
|
||||
])
|
||||
|
||||
|
||||
# Add expert tools if enabled
|
||||
if expert_enabled:
|
||||
tools.extend(EXPERT_TOOLS)
|
||||
|
||||
|
||||
return tools
|
||||
|
||||
def get_web_research_tools(expert_enabled: bool = True) -> list:
|
||||
"""Get the list of tools available for web research.
|
||||
|
||||
|
||||
Args:
|
||||
expert_enabled: Whether expert tools should be included
|
||||
|
||||
human_interaction: Whether to include human interaction tools
|
||||
web_research_enabled: Whether to include web research tools
|
||||
|
||||
Returns:
|
||||
list: List of tools configured for web research
|
||||
"""
|
||||
|
|
@ -153,10 +155,10 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
|
|||
|
||||
def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
|
||||
"""Get the list of tools available in chat mode.
|
||||
|
||||
|
||||
Chat mode includes research and implementation capabilities but excludes
|
||||
complex planning tools. Human interaction is always enabled.
|
||||
|
||||
|
||||
Args:
|
||||
expert_enabled: Whether to include expert tools
|
||||
web_research_enabled: Whether to include web research tools
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
"""Tests for default provider and model configuration."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from ra_aid.env import validate_environment
|
||||
from ra_aid.__main__ import parse_arguments
|
||||
|
||||
@dataclass
|
||||
class MockArgs:
|
||||
"""Mock arguments for testing."""
|
||||
provider: str
|
||||
expert_provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
expert_model: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
research_only: bool = False
|
||||
chat: bool = False
|
||||
|
||||
@pytest.fixture
|
||||
def clean_env(monkeypatch):
|
||||
"""Remove all provider-related environment variables."""
|
||||
env_vars = [
|
||||
"ANTHROPIC_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENAI_API_BASE",
|
||||
"EXPERT_ANTHROPIC_API_KEY",
|
||||
"EXPERT_OPENAI_API_KEY",
|
||||
"EXPERT_OPENAI_API_BASE",
|
||||
"TAVILY_API_KEY",
|
||||
"ANTHROPIC_MODEL",
|
||||
]
|
||||
for var in env_vars:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
yield
|
||||
|
||||
def test_default_anthropic_provider(clean_env, monkeypatch):
|
||||
"""Test that Anthropic is the default provider when no environment variables are set."""
|
||||
args = parse_arguments(["-m", "test message"])
|
||||
assert args.provider == "anthropic"
|
||||
assert args.model == "claude-3-5-sonnet-20241022"
|
||||
|
||||
|
||||
"""Unit tests for provider and model validation in research-only mode."""
|
||||
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
from argparse import Namespace
|
||||
from ra_aid.env import validate_environment
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockArgs:
|
||||
"""Mock command line arguments."""
|
||||
research_only: bool = False
|
||||
provider: str = None
|
||||
model: str = None
|
||||
expert_provider: str = None
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
"research_only_no_provider",
|
||||
MockArgs(research_only=True),
|
||||
{},
|
||||
"No provider specified",
|
||||
id="research_only_no_provider"
|
||||
),
|
||||
pytest.param(
|
||||
"research_only_anthropic",
|
||||
MockArgs(research_only=True, provider="anthropic"),
|
||||
{},
|
||||
None,
|
||||
id="research_only_anthropic"
|
||||
),
|
||||
pytest.param(
|
||||
"research_only_non_anthropic_no_model",
|
||||
MockArgs(research_only=True, provider="openai"),
|
||||
{},
|
||||
"Model is required for non-Anthropic providers",
|
||||
id="research_only_non_anthropic_no_model"
|
||||
),
|
||||
pytest.param(
|
||||
"research_only_non_anthropic_with_model",
|
||||
MockArgs(research_only=True, provider="openai", model="gpt-4"),
|
||||
{},
|
||||
None,
|
||||
id="research_only_non_anthropic_with_model"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_name,args,env_vars,expected_error", TEST_CASES)
|
||||
def test_research_only_provider_validation(
|
||||
test_name: str,
|
||||
args: MockArgs,
|
||||
env_vars: dict,
|
||||
expected_error: str,
|
||||
monkeypatch
|
||||
):
|
||||
"""Test provider and model validation in research-only mode."""
|
||||
# Set test environment variables
|
||||
for key, value in env_vars.items():
|
||||
monkeypatch.setenv(key, value)
|
||||
|
||||
if expected_error:
|
||||
with pytest.raises(SystemExit, match=expected_error):
|
||||
validate_environment(args)
|
||||
else:
|
||||
validate_environment(args)
|
||||
|
|
@ -18,33 +18,32 @@ def clean_env(monkeypatch):
|
|||
env_vars = [
|
||||
'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'OPENROUTER_API_KEY',
|
||||
'OPENAI_API_BASE', 'EXPERT_ANTHROPIC_API_KEY', 'EXPERT_OPENAI_API_KEY',
|
||||
'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'TAVILY_API_KEY'
|
||||
'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'TAVILY_API_KEY', 'ANTHROPIC_MODEL'
|
||||
]
|
||||
for var in env_vars:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
def test_anthropic_validation(clean_env, monkeypatch):
|
||||
args = MockArgs(provider="anthropic", expert_provider="openai")
|
||||
|
||||
args = MockArgs(provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307")
|
||||
|
||||
# Should fail without API key
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(args)
|
||||
|
||||
# Should pass with API key
|
||||
|
||||
# Should pass with API key and model
|
||||
monkeypatch.setenv('ANTHROPIC_API_KEY', 'test-key')
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
assert not expert_enabled
|
||||
assert 'EXPERT_OPENAI_API_KEY environment variable is not set' in expert_missing
|
||||
assert not web_research_enabled
|
||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||
|
||||
def test_openai_validation(clean_env, monkeypatch):
|
||||
args = MockArgs(provider="openai", expert_provider="openai")
|
||||
|
||||
|
||||
# Should fail without API key
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(args)
|
||||
|
||||
|
||||
# Should pass with API key and enable expert mode with fallback
|
||||
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
|
|
@ -56,16 +55,16 @@ def test_openai_validation(clean_env, monkeypatch):
|
|||
|
||||
def test_openai_compatible_validation(clean_env, monkeypatch):
|
||||
args = MockArgs(provider="openai-compatible", expert_provider="openai-compatible")
|
||||
|
||||
|
||||
# Should fail without API key and base URL
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(args)
|
||||
|
||||
|
||||
# Should fail with only API key
|
||||
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(args)
|
||||
|
||||
|
||||
# Should pass with both API key and base URL
|
||||
monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
|
|
@ -78,10 +77,10 @@ def test_openai_compatible_validation(clean_env, monkeypatch):
|
|||
|
||||
def test_expert_fallback(clean_env, monkeypatch):
|
||||
args = MockArgs(provider="openai", expert_provider="openai")
|
||||
|
||||
|
||||
# Set only base API key
|
||||
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
||||
|
||||
|
||||
# Should enable expert mode with fallback
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
assert expert_enabled
|
||||
|
|
@ -89,7 +88,7 @@ def test_expert_fallback(clean_env, monkeypatch):
|
|||
assert not web_research_enabled
|
||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key'
|
||||
|
||||
|
||||
# Should use explicit expert key if available
|
||||
monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key')
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
|
|
@ -101,25 +100,25 @@ def test_expert_fallback(clean_env, monkeypatch):
|
|||
|
||||
def test_cross_provider_fallback(clean_env, monkeypatch):
|
||||
"""Test that fallback works even when providers differ"""
|
||||
args = MockArgs(provider="openai", expert_provider="anthropic")
|
||||
|
||||
args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
|
||||
|
||||
# Set base API key for main provider and expert provider
|
||||
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
||||
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
|
||||
|
||||
monkeypatch.setenv('ANTHROPIC_MODEL', 'claude-3-haiku-20240307')
|
||||
|
||||
# Should enable expert mode with fallback to ANTHROPIC base key
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
assert expert_enabled
|
||||
assert not expert_missing
|
||||
assert not web_research_enabled
|
||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||
assert os.environ.get('EXPERT_ANTHROPIC_API_KEY') == 'anthropic-key'
|
||||
|
||||
# Try with openai-compatible expert provider
|
||||
args = MockArgs(provider="anthropic", expert_provider="openai-compatible")
|
||||
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
||||
monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
|
||||
|
||||
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
assert expert_enabled
|
||||
assert not expert_missing
|
||||
|
|
@ -131,43 +130,51 @@ def test_cross_provider_fallback(clean_env, monkeypatch):
|
|||
def test_no_warning_on_fallback(clean_env, monkeypatch):
|
||||
"""Test that no warning is issued when fallback succeeds"""
|
||||
args = MockArgs(provider="openai", expert_provider="openai")
|
||||
|
||||
|
||||
# Set only base API key
|
||||
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
||||
|
||||
|
||||
# Should enable expert mode with fallback and no warnings
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
assert expert_enabled
|
||||
assert not expert_missing # List should be empty
|
||||
assert not expert_missing
|
||||
assert not web_research_enabled
|
||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key'
|
||||
|
||||
# Should use explicit expert key if available
|
||||
monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key')
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
assert expert_enabled
|
||||
assert not expert_missing
|
||||
assert not web_research_enabled
|
||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'expert-key'
|
||||
|
||||
def test_different_providers_no_expert_key(clean_env, monkeypatch):
|
||||
"""Test behavior when providers differ and only base keys are available"""
|
||||
args = MockArgs(provider="anthropic", expert_provider="openai")
|
||||
|
||||
args = MockArgs(provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307")
|
||||
|
||||
# Set only base keys
|
||||
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
|
||||
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
||||
|
||||
|
||||
# Should enable expert mode and use base OPENAI key
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
assert expert_enabled
|
||||
assert not expert_missing
|
||||
assert not web_research_enabled
|
||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'openai-key'
|
||||
|
||||
def test_mixed_provider_openai_compatible(clean_env, monkeypatch):
|
||||
"""Test behavior with openai-compatible expert and different main provider"""
|
||||
args = MockArgs(provider="anthropic", expert_provider="openai-compatible")
|
||||
|
||||
args = MockArgs(provider="anthropic", expert_provider="openai-compatible", model="claude-3-haiku-20240307")
|
||||
|
||||
# Set all required keys and URLs
|
||||
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
|
||||
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
||||
monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
|
||||
|
||||
|
||||
# Should enable expert mode and use base openai key and URL
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||
assert expert_enabled
|
||||
|
|
|
|||
|
|
@ -207,14 +207,13 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
|
|||
monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Remove fallback
|
||||
monkeypatch.delenv("TAVILY_API_KEY", raising=False) # Remove web research
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") # Add for provider validation
|
||||
monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") # Add model for provider validation
|
||||
args = Args(provider="anthropic", expert_provider="openai") # Change base provider to avoid validation error
|
||||
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
|
||||
assert not expert_enabled
|
||||
assert len(expert_missing) == 1
|
||||
assert expert_missing[0] == "EXPERT_OPENAI_API_KEY environment variable is not set"
|
||||
assert expert_missing
|
||||
assert not web_enabled
|
||||
assert len(web_missing) == 1
|
||||
assert web_missing[0] == "TAVILY_API_KEY environment variable is not set"
|
||||
assert web_missing
|
||||
|
||||
@pytest.fixture
|
||||
def mock_anthropic():
|
||||
|
|
|
|||
|
|
@ -0,0 +1,259 @@
|
|||
"""Integration tests for provider validation and environment handling."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from ra_aid.env import validate_environment
|
||||
from ra_aid.provider_strategy import (
|
||||
ProviderFactory,
|
||||
ValidationResult,
|
||||
AnthropicStrategy,
|
||||
OpenAIStrategy,
|
||||
OpenAICompatibleStrategy,
|
||||
OpenRouterStrategy
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class MockArgs:
|
||||
"""Mock arguments for testing."""
|
||||
provider: str
|
||||
expert_provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
expert_model: Optional[str] = None
|
||||
|
||||
@pytest.fixture
|
||||
def clean_env():
|
||||
"""Remove all provider-related environment variables."""
|
||||
env_vars = [
|
||||
"ANTHROPIC_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENAI_API_BASE",
|
||||
"EXPERT_ANTHROPIC_API_KEY",
|
||||
"EXPERT_OPENAI_API_KEY",
|
||||
"EXPERT_OPENAI_API_BASE",
|
||||
"TAVILY_API_KEY",
|
||||
"ANTHROPIC_MODEL"
|
||||
]
|
||||
|
||||
# Store original values
|
||||
original_values = {}
|
||||
for var in env_vars:
|
||||
original_values[var] = os.environ.get(var)
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
yield
|
||||
|
||||
# Restore original values
|
||||
for var, value in original_values.items():
|
||||
if value is not None:
|
||||
os.environ[var] = value
|
||||
elif var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
def test_provider_validation_respects_cli_args(clean_env):
|
||||
"""Test that provider validation respects CLI args over defaults."""
|
||||
# Set up environment with only OpenAI credentials
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
|
||||
# Should succeed with OpenAI provider
|
||||
args = MockArgs(provider="openai")
|
||||
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
|
||||
assert not expert_missing
|
||||
|
||||
# Should fail with Anthropic provider even though it's first alphabetically
|
||||
args = MockArgs(provider="anthropic", model="claude-3-haiku-20240307")
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(args)
|
||||
|
||||
def test_expert_provider_fallback(clean_env):
|
||||
"""Test expert provider falls back to main provider keys."""
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
args = MockArgs(provider="openai", expert_provider="openai")
|
||||
|
||||
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||
assert expert_enabled
|
||||
assert not expert_missing
|
||||
|
||||
def test_openai_compatible_base_url(clean_env):
|
||||
"""Test OpenAI-compatible provider requires base URL."""
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
args = MockArgs(provider="openai-compatible")
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(args)
|
||||
|
||||
os.environ["OPENAI_API_BASE"] = "http://test"
|
||||
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||
assert not expert_missing
|
||||
|
||||
def test_expert_provider_separate_keys(clean_env):
|
||||
"""Test expert provider can use separate keys."""
|
||||
os.environ["OPENAI_API_KEY"] = "main-key"
|
||||
os.environ["EXPERT_OPENAI_API_KEY"] = "expert-key"
|
||||
|
||||
args = MockArgs(provider="openai", expert_provider="openai")
|
||||
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||
assert expert_enabled
|
||||
assert not expert_missing
|
||||
|
||||
def test_web_research_independent(clean_env):
|
||||
"""Test web research validation is independent of provider."""
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
args = MockArgs(provider="openai")
|
||||
|
||||
# Without Tavily key
|
||||
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
|
||||
assert not web_enabled
|
||||
assert web_missing
|
||||
|
||||
# With Tavily key
|
||||
os.environ["TAVILY_API_KEY"] = "test-key"
|
||||
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
|
||||
assert web_enabled
|
||||
assert not web_missing
|
||||
|
||||
def test_provider_factory_unknown_provider(clean_env):
|
||||
"""Test provider factory handles unknown providers."""
|
||||
strategy = ProviderFactory.create("unknown")
|
||||
assert strategy is None
|
||||
|
||||
args = MockArgs(provider="unknown")
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(args)
|
||||
|
||||
def test_provider_strategy_validation(clean_env):
|
||||
"""Test individual provider strategies."""
|
||||
# Test Anthropic strategy
|
||||
strategy = AnthropicStrategy()
|
||||
result = strategy.validate()
|
||||
assert not result.valid
|
||||
assert "ANTHROPIC_API_KEY environment variable is not set" in result.missing_vars
|
||||
|
||||
os.environ["ANTHROPIC_API_KEY"] = "test-key"
|
||||
os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
|
||||
result = strategy.validate()
|
||||
assert result.valid
|
||||
assert not result.missing_vars
|
||||
|
||||
# Test OpenAI strategy
|
||||
strategy = OpenAIStrategy()
|
||||
result = strategy.validate()
|
||||
assert not result.valid
|
||||
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
result = strategy.validate()
|
||||
assert result.valid
|
||||
assert not result.missing_vars
|
||||
|
||||
def test_missing_provider_arg():
|
||||
"""Test handling of missing provider argument."""
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(None)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(MockArgs(provider=None))
|
||||
|
||||
def test_empty_provider_arg():
|
||||
"""Test handling of empty provider argument."""
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(MockArgs(provider=""))
|
||||
|
||||
def test_incomplete_openai_compatible_config(clean_env):
|
||||
"""Test OpenAI-compatible provider with incomplete configuration."""
|
||||
strategy = OpenAICompatibleStrategy()
|
||||
|
||||
# No configuration
|
||||
result = strategy.validate()
|
||||
assert not result.valid
|
||||
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
|
||||
assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars
|
||||
|
||||
# Only API key
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
result = strategy.validate()
|
||||
assert not result.valid
|
||||
assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars
|
||||
|
||||
# Only base URL
|
||||
os.environ.pop("OPENAI_API_KEY")
|
||||
os.environ["OPENAI_API_BASE"] = "http://test"
|
||||
result = strategy.validate()
|
||||
assert not result.valid
|
||||
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
|
||||
|
||||
def test_incomplete_expert_config(clean_env):
|
||||
"""Test expert provider with incomplete configuration."""
|
||||
# Set main provider but not expert
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
args = MockArgs(provider="openai", expert_provider="openai-compatible")
|
||||
|
||||
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||
assert not expert_enabled
|
||||
assert len(expert_missing) == 1
|
||||
assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]
|
||||
|
||||
# Set expert key but not base URL
|
||||
os.environ["EXPERT_OPENAI_API_KEY"] = "test-key"
|
||||
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||
assert not expert_enabled
|
||||
assert len(expert_missing) == 1
|
||||
assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]
|
||||
|
||||
def test_empty_environment_variables(clean_env):
|
||||
"""Test handling of empty environment variables."""
|
||||
# Empty API key
|
||||
os.environ["OPENAI_API_KEY"] = ""
|
||||
args = MockArgs(provider="openai")
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(args)
|
||||
|
||||
# Empty base URL
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
os.environ["OPENAI_API_BASE"] = ""
|
||||
args = MockArgs(provider="openai-compatible")
|
||||
with pytest.raises(SystemExit):
|
||||
validate_environment(args)
|
||||
|
||||
def test_openrouter_validation(clean_env):
|
||||
"""Test OpenRouter provider validation."""
|
||||
strategy = OpenRouterStrategy()
|
||||
|
||||
# No API key
|
||||
result = strategy.validate()
|
||||
assert not result.valid
|
||||
assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars
|
||||
|
||||
# Empty API key
|
||||
os.environ["OPENROUTER_API_KEY"] = ""
|
||||
result = strategy.validate()
|
||||
assert not result.valid
|
||||
assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars
|
||||
|
||||
# Valid API key
|
||||
os.environ["OPENROUTER_API_KEY"] = "test-key"
|
||||
result = strategy.validate()
|
||||
assert result.valid
|
||||
assert not result.missing_vars
|
||||
|
||||
def test_multiple_expert_providers(clean_env):
|
||||
"""Test validation with multiple expert providers."""
|
||||
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||
os.environ["ANTHROPIC_API_KEY"] = "test-key"
|
||||
os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
|
||||
|
||||
# First expert provider valid, second invalid
|
||||
args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
|
||||
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||
assert expert_enabled
|
||||
assert not expert_missing
|
||||
|
||||
# Switch to invalid provider
|
||||
args = MockArgs(provider="openai", expert_provider="openai-compatible")
|
||||
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||
assert not expert_enabled
|
||||
assert expert_missing
|
||||
Loading…
Reference in New Issue