FEAT fix command line args and env var dependencies on anthropic (#21)
* FIX provider cmdline args * FIX Issue 18 * FIX ensure research-only requires a model
This commit is contained in:
parent
9944ec9ea4
commit
8b3f4d736c
|
|
@ -5,13 +5,12 @@ All notable changes to this project will be documented in this file.
|
||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
## [0.10.3] - 1024-12-27
|
## [0.10.3] - 2024-12-27
|
||||||
|
|
||||||
- Fix logging on interrupt.
|
- Fix logging on interrupt.
|
||||||
- Fix web research prompt.
|
- Fix web research prompt.
|
||||||
- Simplify planning stage by executing tasks directly.
|
- Simplify planning stage by executing tasks directly.
|
||||||
- Make research notes available to more agents/tools.
|
- Make research notes available to more agents/tools.
|
||||||
- Make read_file always output status panel.
|
|
||||||
|
|
||||||
## [0.10.2] - 2024-12-26
|
## [0.10.2] - 2024-12-26
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
.PHONY: test setup-dev setup-hooks
|
||||||
|
|
||||||
|
test:
|
||||||
|
python -m pytest
|
||||||
|
|
||||||
|
setup-dev:
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
|
setup-hooks: setup-dev
|
||||||
|
pre-commit install
|
||||||
|
|
@ -13,11 +13,11 @@ keywords = ["langchain", "ai", "agent", "tools", "development"]
|
||||||
authors = [{name = "AI Christianson", email = "ai.christianson@christianson.ai"}]
|
authors = [{name = "AI Christianson", email = "ai.christianson@christianson.ai"}]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"License :: OSI Approved :: Apache Software License",
|
"License :: OSI Approved :: Apache Software License",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules"
|
"Topic :: Software Development :: Libraries :: Python Modules"
|
||||||
]
|
]
|
||||||
|
|
@ -31,7 +31,7 @@ dependencies = [
|
||||||
"langchain>=0.3.13",
|
"langchain>=0.3.13",
|
||||||
"rich>=13.0.0",
|
"rich>=13.0.0",
|
||||||
"GitPython>=3.1",
|
"GitPython>=3.1",
|
||||||
"fuzzywuzzy==0.18.0",
|
"fuzzywuzzy==0.18.0",
|
||||||
"python-Levenshtein==0.23.0",
|
"python-Levenshtein==0.23.0",
|
||||||
"pathspec>=0.11.0",
|
"pathspec>=0.11.0",
|
||||||
"aider-chat>=0.69.1",
|
"aider-chat>=0.69.1",
|
||||||
|
|
|
||||||
|
|
@ -26,10 +26,14 @@ from ra_aid.logging_config import setup_logging, get_logger
|
||||||
from ra_aid.tool_configs import (
|
from ra_aid.tool_configs import (
|
||||||
get_chat_tools
|
get_chat_tools
|
||||||
)
|
)
|
||||||
|
import os
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments(args=None):
|
||||||
|
VALID_PROVIDERS = ['anthropic', 'openai', 'openrouter', 'openai-compatible']
|
||||||
|
ANTHROPIC_DEFAULT_MODEL = 'claude-3-5-sonnet-20241022'
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='RA.Aid - AI Agent for executing programming and research tasks',
|
description='RA.Aid - AI Agent for executing programming and research tasks',
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
|
@ -59,7 +63,7 @@ Examples:
|
||||||
'--provider',
|
'--provider',
|
||||||
type=str,
|
type=str,
|
||||||
default='anthropic',
|
default='anthropic',
|
||||||
choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'],
|
choices=VALID_PROVIDERS,
|
||||||
help='The LLM provider to use'
|
help='The LLM provider to use'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -76,7 +80,7 @@ Examples:
|
||||||
'--expert-provider',
|
'--expert-provider',
|
||||||
type=str,
|
type=str,
|
||||||
default='openai',
|
default='openai',
|
||||||
choices=['anthropic', 'openai', 'openrouter', 'openai-compatible'],
|
choices=VALID_PROVIDERS,
|
||||||
help='The LLM provider to use for expert knowledge queries (default: openai)'
|
help='The LLM provider to use for expert knowledge queries (default: openai)'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -99,25 +103,32 @@ Examples:
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='Enable verbose logging output'
|
help='Enable verbose logging output'
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
if args is None:
|
||||||
|
args = sys.argv[1:]
|
||||||
|
parsed_args = parser.parse_args(args)
|
||||||
|
|
||||||
# Set hil=True when chat mode is enabled
|
# Set hil=True when chat mode is enabled
|
||||||
if args.chat:
|
if parsed_args.chat:
|
||||||
args.hil = True
|
parsed_args.hil = True
|
||||||
|
|
||||||
# Set default model for Anthropic, require model for other providers
|
# Validate provider
|
||||||
if args.provider == 'anthropic':
|
if parsed_args.provider not in VALID_PROVIDERS:
|
||||||
if not args.model:
|
parser.error(f"Invalid provider: {parsed_args.provider}")
|
||||||
args.model = 'claude-3-5-sonnet-20241022'
|
|
||||||
elif not args.model:
|
# Handle model defaults and requirements
|
||||||
parser.error(f"--model is required when using provider '{args.provider}'")
|
if parsed_args.provider == 'anthropic':
|
||||||
|
# Always use default model for Anthropic
|
||||||
|
parsed_args.model = ANTHROPIC_DEFAULT_MODEL
|
||||||
|
elif not parsed_args.model and not parsed_args.research_only:
|
||||||
|
# Require model for other providers unless in research mode
|
||||||
|
parser.error(f"--model is required when using provider '{parsed_args.provider}'")
|
||||||
|
|
||||||
# Validate expert model requirement
|
# Validate expert model requirement
|
||||||
if args.expert_provider != 'openai' and not args.expert_model:
|
if parsed_args.expert_provider != 'openai' and not parsed_args.expert_model and not parsed_args.research_only:
|
||||||
parser.error(f"--expert-model is required when using expert provider '{args.expert_provider}'")
|
parser.error(f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'")
|
||||||
|
|
||||||
return args
|
return parsed_args
|
||||||
|
|
||||||
# Create console instance
|
# Create console instance
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
@ -143,14 +154,14 @@ def main():
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
setup_logging(args.verbose)
|
setup_logging(args.verbose)
|
||||||
logger.debug("Starting RA.Aid with arguments: %s", args)
|
logger.debug("Starting RA.Aid with arguments: %s", args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) # Will exit if main env vars missing
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) # Will exit if main env vars missing
|
||||||
logger.debug("Environment validation successful")
|
logger.debug("Environment validation successful")
|
||||||
|
|
||||||
if expert_missing:
|
if expert_missing:
|
||||||
console.print(Panel(
|
console.print(Panel(
|
||||||
f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n" +
|
f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n" +
|
||||||
"\n".join(f"- {m}" for m in expert_missing) +
|
"\n".join(f"- {m}" for m in expert_missing) +
|
||||||
"\nSet the required environment variables or args to enable expert mode.",
|
"\nSet the required environment variables or args to enable expert mode.",
|
||||||
title="Expert Tools Disabled",
|
title="Expert Tools Disabled",
|
||||||
|
|
@ -159,20 +170,24 @@ def main():
|
||||||
|
|
||||||
if web_research_missing:
|
if web_research_missing:
|
||||||
console.print(Panel(
|
console.print(Panel(
|
||||||
f"[yellow]Web research disabled due to missing configuration:[/yellow]\n" +
|
f"[yellow]Web research disabled due to missing configuration:[/yellow]\n" +
|
||||||
"\n".join(f"- {m}" for m in web_research_missing) +
|
"\n".join(f"- {m}" for m in web_research_missing) +
|
||||||
"\nSet the required environment variables to enable web research.",
|
"\nSet the required environment variables to enable web research.",
|
||||||
title="Web Research Disabled",
|
title="Web Research Disabled",
|
||||||
style="yellow"
|
style="yellow"
|
||||||
))
|
))
|
||||||
|
|
||||||
# Create the base model after validation
|
# Create the base model after validation
|
||||||
model = initialize_llm(args.provider, args.model)
|
model = initialize_llm(args.provider, args.model)
|
||||||
|
|
||||||
# Handle chat mode
|
# Handle chat mode
|
||||||
if args.chat:
|
if args.chat:
|
||||||
|
if args.research_only:
|
||||||
|
print_error("Chat mode cannot be used with --research-only")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
print_stage_header("Chat Mode")
|
print_stage_header("Chat Mode")
|
||||||
|
|
||||||
# Get initial request from user
|
# Get initial request from user
|
||||||
initial_request = ask_human.invoke({"question": "What would you like help with?"})
|
initial_request = ask_human.invoke({"question": "What would you like help with?"})
|
||||||
|
|
||||||
|
|
@ -182,7 +197,7 @@ def main():
|
||||||
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
|
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
|
||||||
checkpointer=MemorySaver()
|
checkpointer=MemorySaver()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run chat agent with CHAT_PROMPT
|
# Run chat agent with CHAT_PROMPT
|
||||||
config = {
|
config = {
|
||||||
"configurable": {"thread_id": uuid.uuid4()},
|
"configurable": {"thread_id": uuid.uuid4()},
|
||||||
|
|
@ -193,14 +208,14 @@ def main():
|
||||||
"web_research_enabled": web_research_enabled,
|
"web_research_enabled": web_research_enabled,
|
||||||
"initial_request": initial_request
|
"initial_request": initial_request
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store config in global memory
|
# Store config in global memory
|
||||||
_global_memory['config'] = config
|
_global_memory['config'] = config
|
||||||
_global_memory['config']['provider'] = args.provider
|
_global_memory['config']['provider'] = args.provider
|
||||||
_global_memory['config']['model'] = args.model
|
_global_memory['config']['model'] = args.model
|
||||||
_global_memory['config']['expert_provider'] = args.expert_provider
|
_global_memory['config']['expert_provider'] = args.expert_provider
|
||||||
_global_memory['config']['expert_model'] = args.expert_model
|
_global_memory['config']['expert_model'] = args.expert_model
|
||||||
|
|
||||||
# Run chat agent and exit
|
# Run chat agent and exit
|
||||||
run_agent_with_retry(chat_agent, CHAT_PROMPT.format(
|
run_agent_with_retry(chat_agent, CHAT_PROMPT.format(
|
||||||
initial_request=initial_request,
|
initial_request=initial_request,
|
||||||
|
|
@ -212,7 +227,7 @@ def main():
|
||||||
if not args.message:
|
if not args.message:
|
||||||
print_error("--message is required")
|
print_error("--message is required")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
base_task = args.message
|
base_task = args.message
|
||||||
config = {
|
config = {
|
||||||
"configurable": {"thread_id": uuid.uuid4()},
|
"configurable": {"thread_id": uuid.uuid4()},
|
||||||
|
|
@ -221,21 +236,21 @@ def main():
|
||||||
"cowboy_mode": args.cowboy_mode,
|
"cowboy_mode": args.cowboy_mode,
|
||||||
"web_research_enabled": web_research_enabled
|
"web_research_enabled": web_research_enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store config in global memory for access by is_informational_query
|
# Store config in global memory for access by is_informational_query
|
||||||
_global_memory['config'] = config
|
_global_memory['config'] = config
|
||||||
|
|
||||||
# Store model configuration
|
# Store model configuration
|
||||||
_global_memory['config']['provider'] = args.provider
|
_global_memory['config']['provider'] = args.provider
|
||||||
_global_memory['config']['model'] = args.model
|
_global_memory['config']['model'] = args.model
|
||||||
|
|
||||||
# Store expert provider and model in config
|
# Store expert provider and model in config
|
||||||
_global_memory['config']['expert_provider'] = args.expert_provider
|
_global_memory['config']['expert_provider'] = args.expert_provider
|
||||||
_global_memory['config']['expert_model'] = args.expert_model
|
_global_memory['config']['expert_model'] = args.expert_model
|
||||||
|
|
||||||
# Run research stage
|
# Run research stage
|
||||||
print_stage_header("Research Stage")
|
print_stage_header("Research Stage")
|
||||||
|
|
||||||
run_research_agent(
|
run_research_agent(
|
||||||
base_task,
|
base_task,
|
||||||
model,
|
model,
|
||||||
|
|
@ -245,7 +260,7 @@ def main():
|
||||||
memory=research_memory,
|
memory=research_memory,
|
||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
# Proceed with planning and implementation if not an informational query
|
# Proceed with planning and implementation if not an informational query
|
||||||
if not is_informational_query():
|
if not is_informational_query():
|
||||||
# Run planning agent
|
# Run planning agent
|
||||||
|
|
|
||||||
|
|
@ -162,9 +162,24 @@ def run_research_agent(
|
||||||
if console_message:
|
if console_message:
|
||||||
console.print(Panel(Markdown(console_message), title="🔬 Looking into it..."))
|
console.print(Panel(Markdown(console_message), title="🔬 Looking into it..."))
|
||||||
|
|
||||||
# Run agent with retry logic
|
# Run agent with retry logic if available
|
||||||
logger.debug("Research agent completed successfully")
|
if agent is not None:
|
||||||
return run_agent_with_retry(agent, prompt, run_config)
|
logger.debug("Research agent completed successfully")
|
||||||
|
return run_agent_with_retry(agent, prompt, run_config)
|
||||||
|
else:
|
||||||
|
# Just run web research tools directly if no agent
|
||||||
|
logger.debug("No model provided, running web research tools directly")
|
||||||
|
return run_web_research_agent(
|
||||||
|
base_task_or_query,
|
||||||
|
model=None,
|
||||||
|
expert_enabled=expert_enabled,
|
||||||
|
hil=hil,
|
||||||
|
web_research_enabled=web_research_enabled,
|
||||||
|
memory=memory,
|
||||||
|
config=config,
|
||||||
|
thread_id=thread_id,
|
||||||
|
console_message=console_message
|
||||||
|
)
|
||||||
except (KeyboardInterrupt, AgentInterrupt):
|
except (KeyboardInterrupt, AgentInterrupt):
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -255,11 +270,41 @@ def run_web_research_agent(
|
||||||
try:
|
try:
|
||||||
# Display console message if provided
|
# Display console message if provided
|
||||||
if console_message:
|
if console_message:
|
||||||
console.print(Panel(Markdown(console_message), title="🔍 Starting Web Research..."))
|
console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
|
||||||
|
|
||||||
|
# Run agent with retry logic if available
|
||||||
|
if agent is not None:
|
||||||
|
logger.debug("Web research agent completed successfully")
|
||||||
|
return run_agent_with_retry(agent, prompt, run_config)
|
||||||
|
else:
|
||||||
|
# Just use the web research tools directly
|
||||||
|
logger.debug("No model provided, using web research tools directly")
|
||||||
|
tavily_tool = next((tool for tool in tools if tool.name == 'web_search_tavily'), None)
|
||||||
|
|
||||||
|
if not tavily_tool:
|
||||||
|
return "No web research results found"
|
||||||
|
|
||||||
|
result = tavily_tool.invoke({"query": query})
|
||||||
|
if not result:
|
||||||
|
return "No web research results found"
|
||||||
|
|
||||||
|
# Format Tavily results
|
||||||
|
markdown_result = "# Search Results\n\n"
|
||||||
|
for item in result.get('results', []):
|
||||||
|
title = item.get('title', 'Untitled')
|
||||||
|
url = item.get('url', '')
|
||||||
|
content = item.get('content', '')
|
||||||
|
score = item.get('score', 0)
|
||||||
|
|
||||||
|
markdown_result += f"## {title}\n"
|
||||||
|
markdown_result += f"**Score**: {score:.2f}\n\n"
|
||||||
|
markdown_result += f"{content}\n\n"
|
||||||
|
markdown_result += f"[Read more]({url})\n\n"
|
||||||
|
markdown_result += "---\n\n"
|
||||||
|
|
||||||
|
console.print(Panel(Markdown(markdown_result), title="🔍 Web Research Results"))
|
||||||
|
return markdown_result
|
||||||
|
|
||||||
# Run agent with retry logic
|
|
||||||
logger.debug("Web research agent completed successfully")
|
|
||||||
return run_agent_with_retry(agent, prompt, run_config)
|
|
||||||
except (KeyboardInterrupt, AgentInterrupt):
|
except (KeyboardInterrupt, AgentInterrupt):
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
270
ra_aid/env.py
270
ra_aid/env.py
|
|
@ -3,103 +3,209 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple, List
|
from typing import Tuple, List, Any
|
||||||
|
|
||||||
from ra_aid import print_error
|
from ra_aid import print_error
|
||||||
|
from ra_aid.provider_strategy import ProviderFactory, ValidationResult
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProviderConfig:
|
class ValidationResult:
|
||||||
"""Configuration for a provider."""
|
"""Result of validation."""
|
||||||
key_name: str
|
valid: bool
|
||||||
base_required: bool = False
|
missing_vars: List[str]
|
||||||
|
|
||||||
PROVIDER_CONFIGS = {
|
def validate_provider(provider: str) -> ValidationResult:
|
||||||
"anthropic": ProviderConfig("ANTHROPIC_API_KEY", base_required=True),
|
"""Validate provider configuration."""
|
||||||
"openai": ProviderConfig("OPENAI_API_KEY", base_required=True),
|
if not provider:
|
||||||
"openrouter": ProviderConfig("OPENROUTER_API_KEY", base_required=True),
|
return ValidationResult(valid=False, missing_vars=["No provider specified"])
|
||||||
"openai-compatible": ProviderConfig("OPENAI_API_KEY", base_required=True),
|
strategy = ProviderFactory.create(provider)
|
||||||
}
|
if not strategy:
|
||||||
|
return ValidationResult(valid=False, missing_vars=[f"Unknown provider: {provider}"])
|
||||||
|
return strategy.validate()
|
||||||
|
|
||||||
|
def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None:
|
||||||
|
"""Copy base provider environment variables to expert provider if not set.
|
||||||
|
|
||||||
def validate_environment(args) -> Tuple[bool, List[str], bool, List[str]]:
|
|
||||||
"""Validate required environment variables and dependencies.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
args: The parsed command line arguments containing:
|
base_provider: Base provider name
|
||||||
- provider: The main LLM provider
|
expert_provider: Expert provider name
|
||||||
- expert_provider: The expert LLM provider
|
"""
|
||||||
|
# Map of base to expert environment variables for each provider
|
||||||
|
provider_vars = {
|
||||||
|
'openai': {
|
||||||
|
'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY',
|
||||||
|
'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE'
|
||||||
|
},
|
||||||
|
'openai-compatible': {
|
||||||
|
'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY',
|
||||||
|
'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE'
|
||||||
|
},
|
||||||
|
'anthropic': {
|
||||||
|
'ANTHROPIC_API_KEY': 'EXPERT_ANTHROPIC_API_KEY',
|
||||||
|
'ANTHROPIC_MODEL': 'EXPERT_ANTHROPIC_MODEL'
|
||||||
|
},
|
||||||
|
'openrouter': {
|
||||||
|
'OPENROUTER_API_KEY': 'EXPERT_OPENROUTER_API_KEY'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get the variables to copy based on the expert provider
|
||||||
|
vars_to_copy = provider_vars.get(expert_provider, {})
|
||||||
|
for base_var, expert_var in vars_to_copy.items():
|
||||||
|
# Only copy if expert var is not set and base var exists
|
||||||
|
if not os.environ.get(expert_var) and os.environ.get(base_var):
|
||||||
|
os.environ[expert_var] = os.environ[base_var]
|
||||||
|
|
||||||
|
def validate_expert_provider(provider: str) -> ValidationResult:
|
||||||
|
"""Validate expert provider configuration with fallback."""
|
||||||
|
if not provider:
|
||||||
|
return ValidationResult(valid=True, missing_vars=[])
|
||||||
|
|
||||||
|
strategy = ProviderFactory.create(provider)
|
||||||
|
if not strategy:
|
||||||
|
return ValidationResult(valid=False, missing_vars=[f"Unknown expert provider: {provider}"])
|
||||||
|
|
||||||
|
# Copy base vars to expert vars for fallback
|
||||||
|
copy_base_to_expert_vars(provider, provider)
|
||||||
|
|
||||||
|
# Validate expert configuration
|
||||||
|
result = strategy.validate()
|
||||||
|
missing = []
|
||||||
|
|
||||||
|
for var in result.missing_vars:
|
||||||
|
key = var.split()[0] # Get the key name without the error message
|
||||||
|
expert_key = f"EXPERT_{key}"
|
||||||
|
if not os.environ.get(expert_key):
|
||||||
|
missing.append(f"{expert_key} environment variable is not set")
|
||||||
|
|
||||||
|
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||||
|
|
||||||
|
def validate_web_research() -> ValidationResult:
|
||||||
|
"""Validate web research configuration."""
|
||||||
|
key = "TAVILY_API_KEY"
|
||||||
|
return ValidationResult(
|
||||||
|
valid=bool(os.environ.get(key)),
|
||||||
|
missing_vars=[] if os.environ.get(key) else [f"{key} environment variable is not set"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def print_missing_dependencies(missing_vars: List[str]) -> None:
|
||||||
|
"""Print missing dependencies and exit."""
|
||||||
|
for var in missing_vars:
|
||||||
|
print(f"Error: {var}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def validate_research_only_provider(args: Any) -> None:
|
||||||
|
"""Validate provider and model for research-only mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Arguments containing provider and expert provider settings
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SystemExit: If provider or model validation fails
|
||||||
|
"""
|
||||||
|
# Get provider from args
|
||||||
|
provider = args.provider if args and hasattr(args, 'provider') else None
|
||||||
|
if not provider:
|
||||||
|
sys.exit("No provider specified")
|
||||||
|
|
||||||
|
# For non-Anthropic providers in research-only mode, model must be specified
|
||||||
|
if provider != 'anthropic':
|
||||||
|
model = args.model if hasattr(args, 'model') and args.model else None
|
||||||
|
if not model:
|
||||||
|
sys.exit("Model is required for non-Anthropic providers")
|
||||||
|
|
||||||
|
def validate_research_only(args: Any) -> tuple[bool, list[str], bool, list[str]]:
|
||||||
|
"""Validate environment variables for research-only mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Arguments containing provider and expert provider settings
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
Tuple containing:
|
||||||
- bool: Whether expert mode is enabled
|
- expert_enabled: Whether expert mode is enabled
|
||||||
- List[str]: List of missing expert configuration items
|
- expert_missing: List of missing expert dependencies
|
||||||
- bool: Whether web research is enabled
|
- web_research_enabled: Whether web research is enabled
|
||||||
- List[str]: List of missing web research configuration items
|
- web_research_missing: List of missing web research dependencies
|
||||||
|
|
||||||
Raises:
|
|
||||||
SystemExit: If required base environment variables are missing
|
|
||||||
"""
|
"""
|
||||||
missing = []
|
# Initialize results
|
||||||
provider = args.provider
|
expert_enabled = False
|
||||||
expert_provider = args.expert_provider
|
|
||||||
|
|
||||||
# Check API keys based on provider configs
|
|
||||||
if provider in PROVIDER_CONFIGS:
|
|
||||||
config = PROVIDER_CONFIGS[provider]
|
|
||||||
if config.base_required and not os.environ.get(config.key_name):
|
|
||||||
missing.append(f'{config.key_name} environment variable is not set')
|
|
||||||
|
|
||||||
# Special case for openai-compatible needing base URL
|
|
||||||
if provider == "openai-compatible" and not os.environ.get('OPENAI_API_BASE'):
|
|
||||||
missing.append('OPENAI_API_BASE environment variable is not set')
|
|
||||||
|
|
||||||
expert_missing = []
|
expert_missing = []
|
||||||
if expert_provider in PROVIDER_CONFIGS:
|
|
||||||
config = PROVIDER_CONFIGS[expert_provider]
|
|
||||||
expert_key = f'EXPERT_{config.key_name}'
|
|
||||||
expert_key_missing = not os.environ.get(expert_key)
|
|
||||||
|
|
||||||
# Try fallback to base key for expert provider
|
|
||||||
fallback_available = os.environ.get(config.key_name)
|
|
||||||
if expert_key_missing and fallback_available:
|
|
||||||
os.environ[expert_key] = os.environ[config.key_name]
|
|
||||||
expert_key_missing = False
|
|
||||||
|
|
||||||
# Only add to missing list if still missing after fallback attempt
|
|
||||||
if expert_key_missing:
|
|
||||||
expert_missing.append(f'{expert_key} environment variable is not set')
|
|
||||||
|
|
||||||
# Special case for openai-compatible expert needing base URL
|
|
||||||
if expert_provider == "openai-compatible":
|
|
||||||
expert_base = 'EXPERT_OPENAI_API_BASE'
|
|
||||||
base_missing = not os.environ.get(expert_base)
|
|
||||||
base_fallback = os.environ.get('OPENAI_API_BASE')
|
|
||||||
|
|
||||||
if base_missing and base_fallback:
|
|
||||||
os.environ[expert_base] = os.environ['OPENAI_API_BASE']
|
|
||||||
base_missing = False
|
|
||||||
|
|
||||||
if base_missing:
|
|
||||||
expert_missing.append(f'{expert_base} environment variable is not set')
|
|
||||||
|
|
||||||
# If main keys missing, we must exit immediately
|
|
||||||
if missing:
|
|
||||||
print_error("Missing required dependencies:")
|
|
||||||
for item in missing:
|
|
||||||
print_error(f"- {item}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# If expert keys missing, we disable expert tools instead of exiting
|
|
||||||
expert_enabled = True
|
|
||||||
if expert_missing:
|
|
||||||
expert_enabled = False
|
|
||||||
|
|
||||||
# Check web research dependencies
|
|
||||||
web_research_missing = []
|
|
||||||
web_research_enabled = False
|
web_research_enabled = False
|
||||||
|
web_research_missing = []
|
||||||
if not os.environ.get('TAVILY_API_KEY'):
|
|
||||||
|
# Validate web research dependencies
|
||||||
|
tavily_key = os.environ.get('TAVILY_API_KEY')
|
||||||
|
if not tavily_key:
|
||||||
web_research_missing.append('TAVILY_API_KEY environment variable is not set')
|
web_research_missing.append('TAVILY_API_KEY environment variable is not set')
|
||||||
else:
|
else:
|
||||||
web_research_enabled = True
|
web_research_enabled = True
|
||||||
|
|
||||||
return expert_enabled, expert_missing, web_research_enabled, web_research_missing
|
return expert_enabled, expert_missing, web_research_enabled, web_research_missing
|
||||||
|
|
||||||
|
def validate_environment(args: Any) -> tuple[bool, list[str], bool, list[str]]:
|
||||||
|
"""Validate environment variables for providers and web research tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Arguments containing provider and expert provider settings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing:
|
||||||
|
- expert_enabled: Whether expert mode is enabled
|
||||||
|
- expert_missing: List of missing expert dependencies
|
||||||
|
- web_research_enabled: Whether web research is enabled
|
||||||
|
- web_research_missing: List of missing web research dependencies
|
||||||
|
"""
|
||||||
|
# For research-only mode, use separate validation
|
||||||
|
if hasattr(args, 'research_only') and args.research_only:
|
||||||
|
# Only validate provider and model when testing provider validation
|
||||||
|
if hasattr(args, 'model') and args.model is None:
|
||||||
|
validate_research_only_provider(args)
|
||||||
|
return validate_research_only(args)
|
||||||
|
|
||||||
|
# Initialize results
|
||||||
|
expert_enabled = False
|
||||||
|
expert_missing = []
|
||||||
|
web_research_enabled = False
|
||||||
|
web_research_missing = []
|
||||||
|
|
||||||
|
# Get provider from args
|
||||||
|
provider = args.provider if args and hasattr(args, 'provider') else None
|
||||||
|
if not provider:
|
||||||
|
sys.exit("No provider specified")
|
||||||
|
|
||||||
|
# Validate main provider
|
||||||
|
strategy = ProviderFactory.create(provider, args)
|
||||||
|
if not strategy:
|
||||||
|
sys.exit(f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
result = strategy.validate(args)
|
||||||
|
if not result.valid:
|
||||||
|
print_missing_dependencies(result.missing_vars)
|
||||||
|
|
||||||
|
# Handle expert provider if enabled
|
||||||
|
if args.expert_provider:
|
||||||
|
# Copy base variables to expert if not set
|
||||||
|
copy_base_to_expert_vars(provider, args.expert_provider)
|
||||||
|
|
||||||
|
# Validate expert provider
|
||||||
|
expert_strategy = ProviderFactory.create(args.expert_provider, args)
|
||||||
|
if not expert_strategy:
|
||||||
|
sys.exit(f"Unknown expert provider: {args.expert_provider}")
|
||||||
|
|
||||||
|
expert_result = expert_strategy.validate(args)
|
||||||
|
expert_missing = expert_result.missing_vars
|
||||||
|
expert_enabled = len(expert_missing) == 0
|
||||||
|
|
||||||
|
# If expert validation failed, try to copy base variables again and revalidate
|
||||||
|
if not expert_enabled:
|
||||||
|
copy_base_to_expert_vars(provider, args.expert_provider)
|
||||||
|
expert_result = expert_strategy.validate(args)
|
||||||
|
expert_missing = expert_result.missing_vars
|
||||||
|
expert_enabled = len(expert_missing) == 0
|
||||||
|
|
||||||
|
# Validate web research dependencies
|
||||||
|
web_result = validate_web_research()
|
||||||
|
web_research_enabled = web_result.valid
|
||||||
|
web_research_missing = web_result.missing_vars
|
||||||
|
|
||||||
|
return expert_enabled, expert_missing, web_research_enabled, web_research_missing
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,237 @@
|
||||||
|
"""Provider validation strategies."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, List, Any
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationResult:
|
||||||
|
"""Result of validation."""
|
||||||
|
valid: bool
|
||||||
|
missing_vars: List[str]
|
||||||
|
|
||||||
|
class ProviderStrategy(ABC):
|
||||||
|
"""Abstract base class for provider validation strategies."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate(self, args: Optional[Any] = None) -> ValidationResult:
|
||||||
|
"""Validate provider environment variables."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class OpenAIStrategy(ProviderStrategy):
|
||||||
|
"""OpenAI provider validation strategy."""
|
||||||
|
|
||||||
|
def validate(self, args: Optional[Any] = None) -> ValidationResult:
|
||||||
|
"""Validate OpenAI environment variables."""
|
||||||
|
missing = []
|
||||||
|
|
||||||
|
# Check if we're validating expert config
|
||||||
|
if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openai':
|
||||||
|
key = os.environ.get('EXPERT_OPENAI_API_KEY')
|
||||||
|
if not key or key == '':
|
||||||
|
# Try to copy from base if not set
|
||||||
|
base_key = os.environ.get('OPENAI_API_KEY')
|
||||||
|
if base_key:
|
||||||
|
os.environ['EXPERT_OPENAI_API_KEY'] = base_key
|
||||||
|
key = base_key
|
||||||
|
if not key:
|
||||||
|
missing.append('EXPERT_OPENAI_API_KEY environment variable is not set')
|
||||||
|
|
||||||
|
# Check expert model only for research-only mode
|
||||||
|
if hasattr(args, 'research_only') and args.research_only:
|
||||||
|
model = args.expert_model if hasattr(args, 'expert_model') else None
|
||||||
|
if not model:
|
||||||
|
model = os.environ.get('EXPERT_OPENAI_MODEL')
|
||||||
|
if not model:
|
||||||
|
model = os.environ.get('OPENAI_MODEL')
|
||||||
|
if not model:
|
||||||
|
missing.append('Model is required for OpenAI provider in research-only mode')
|
||||||
|
else:
|
||||||
|
key = os.environ.get('OPENAI_API_KEY')
|
||||||
|
if not key:
|
||||||
|
missing.append('OPENAI_API_KEY environment variable is not set')
|
||||||
|
|
||||||
|
# Check model only for research-only mode
|
||||||
|
if hasattr(args, 'research_only') and args.research_only:
|
||||||
|
model = args.model if hasattr(args, 'model') else None
|
||||||
|
if not model:
|
||||||
|
model = os.environ.get('OPENAI_MODEL')
|
||||||
|
if not model:
|
||||||
|
missing.append('Model is required for OpenAI provider in research-only mode')
|
||||||
|
|
||||||
|
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||||
|
|
||||||
|
class OpenAICompatibleStrategy(ProviderStrategy):
|
||||||
|
"""OpenAI-compatible provider validation strategy."""
|
||||||
|
|
||||||
|
def validate(self, args: Optional[Any] = None) -> ValidationResult:
|
||||||
|
"""Validate OpenAI-compatible environment variables."""
|
||||||
|
missing = []
|
||||||
|
|
||||||
|
# Check if we're validating expert config
|
||||||
|
if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openai-compatible':
|
||||||
|
key = os.environ.get('EXPERT_OPENAI_API_KEY')
|
||||||
|
base = os.environ.get('EXPERT_OPENAI_API_BASE')
|
||||||
|
|
||||||
|
# Try to copy from base if not set
|
||||||
|
if not key or key == '':
|
||||||
|
base_key = os.environ.get('OPENAI_API_KEY')
|
||||||
|
if base_key:
|
||||||
|
os.environ['EXPERT_OPENAI_API_KEY'] = base_key
|
||||||
|
key = base_key
|
||||||
|
if not base or base == '':
|
||||||
|
base_base = os.environ.get('OPENAI_API_BASE')
|
||||||
|
if base_base:
|
||||||
|
os.environ['EXPERT_OPENAI_API_BASE'] = base_base
|
||||||
|
base = base_base
|
||||||
|
|
||||||
|
if not key:
|
||||||
|
missing.append('EXPERT_OPENAI_API_KEY environment variable is not set')
|
||||||
|
if not base:
|
||||||
|
missing.append('EXPERT_OPENAI_API_BASE environment variable is not set')
|
||||||
|
|
||||||
|
# Check expert model only for research-only mode
|
||||||
|
if hasattr(args, 'research_only') and args.research_only:
|
||||||
|
model = args.expert_model if hasattr(args, 'expert_model') else None
|
||||||
|
if not model:
|
||||||
|
model = os.environ.get('EXPERT_OPENAI_MODEL')
|
||||||
|
if not model:
|
||||||
|
model = os.environ.get('OPENAI_MODEL')
|
||||||
|
if not model:
|
||||||
|
missing.append('Model is required for OpenAI-compatible provider in research-only mode')
|
||||||
|
else:
|
||||||
|
key = os.environ.get('OPENAI_API_KEY')
|
||||||
|
base = os.environ.get('OPENAI_API_BASE')
|
||||||
|
|
||||||
|
if not key:
|
||||||
|
missing.append('OPENAI_API_KEY environment variable is not set')
|
||||||
|
if not base:
|
||||||
|
missing.append('OPENAI_API_BASE environment variable is not set')
|
||||||
|
|
||||||
|
# Check model only for research-only mode
|
||||||
|
if hasattr(args, 'research_only') and args.research_only:
|
||||||
|
model = args.model if hasattr(args, 'model') else None
|
||||||
|
if not model:
|
||||||
|
model = os.environ.get('OPENAI_MODEL')
|
||||||
|
if not model:
|
||||||
|
missing.append('Model is required for OpenAI-compatible provider in research-only mode')
|
||||||
|
|
||||||
|
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||||
|
|
||||||
|
class AnthropicStrategy(ProviderStrategy):
|
||||||
|
"""Anthropic provider validation strategy."""
|
||||||
|
|
||||||
|
VALID_MODELS = [
|
||||||
|
"claude-"
|
||||||
|
]
|
||||||
|
|
||||||
|
def validate(self, args: Optional[Any] = None) -> ValidationResult:
|
||||||
|
"""Validate Anthropic environment variables and model."""
|
||||||
|
missing = []
|
||||||
|
|
||||||
|
# Check if we're validating expert config
|
||||||
|
is_expert = args and hasattr(args, 'expert_provider') and args.expert_provider == 'anthropic'
|
||||||
|
|
||||||
|
# Check API key
|
||||||
|
if is_expert:
|
||||||
|
key = os.environ.get('EXPERT_ANTHROPIC_API_KEY')
|
||||||
|
if not key or key == '':
|
||||||
|
# Try to copy from base if not set
|
||||||
|
base_key = os.environ.get('ANTHROPIC_API_KEY')
|
||||||
|
if base_key:
|
||||||
|
os.environ['EXPERT_ANTHROPIC_API_KEY'] = base_key
|
||||||
|
key = base_key
|
||||||
|
if not key:
|
||||||
|
missing.append('EXPERT_ANTHROPIC_API_KEY environment variable is not set')
|
||||||
|
else:
|
||||||
|
key = os.environ.get('ANTHROPIC_API_KEY')
|
||||||
|
if not key:
|
||||||
|
missing.append('ANTHROPIC_API_KEY environment variable is not set')
|
||||||
|
|
||||||
|
# Check model
|
||||||
|
model_matched = False
|
||||||
|
model_to_check = None
|
||||||
|
|
||||||
|
# First check command line argument
|
||||||
|
if is_expert:
|
||||||
|
if hasattr(args, 'expert_model') and args.expert_model:
|
||||||
|
model_to_check = args.expert_model
|
||||||
|
else:
|
||||||
|
# If no expert model, check environment variable
|
||||||
|
model_to_check = os.environ.get('EXPERT_ANTHROPIC_MODEL')
|
||||||
|
if not model_to_check or model_to_check == '':
|
||||||
|
# Try to copy from base if not set
|
||||||
|
base_model = os.environ.get('ANTHROPIC_MODEL')
|
||||||
|
if base_model:
|
||||||
|
os.environ['EXPERT_ANTHROPIC_MODEL'] = base_model
|
||||||
|
model_to_check = base_model
|
||||||
|
else:
|
||||||
|
if hasattr(args, 'model') and args.model:
|
||||||
|
model_to_check = args.model
|
||||||
|
else:
|
||||||
|
model_to_check = os.environ.get('ANTHROPIC_MODEL')
|
||||||
|
|
||||||
|
if not model_to_check:
|
||||||
|
missing.append('ANTHROPIC_MODEL environment variable is not set')
|
||||||
|
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||||
|
|
||||||
|
# Validate model format
|
||||||
|
for pattern in self.VALID_MODELS:
|
||||||
|
if re.match(pattern, model_to_check):
|
||||||
|
model_matched = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not model_matched:
|
||||||
|
missing.append(f'Invalid Anthropic model: {model_to_check}. Must match one of these patterns: {", ".join(self.VALID_MODELS)}')
|
||||||
|
|
||||||
|
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||||
|
|
||||||
|
class OpenRouterStrategy(ProviderStrategy):
|
||||||
|
"""OpenRouter provider validation strategy."""
|
||||||
|
|
||||||
|
def validate(self, args: Optional[Any] = None) -> ValidationResult:
|
||||||
|
"""Validate OpenRouter environment variables."""
|
||||||
|
missing = []
|
||||||
|
|
||||||
|
# Check if we're validating expert config
|
||||||
|
if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openrouter':
|
||||||
|
key = os.environ.get('EXPERT_OPENROUTER_API_KEY')
|
||||||
|
if not key or key == '':
|
||||||
|
# Try to copy from base if not set
|
||||||
|
base_key = os.environ.get('OPENROUTER_API_KEY')
|
||||||
|
if base_key:
|
||||||
|
os.environ['EXPERT_OPENROUTER_API_KEY'] = base_key
|
||||||
|
key = base_key
|
||||||
|
if not key:
|
||||||
|
missing.append('EXPERT_OPENROUTER_API_KEY environment variable is not set')
|
||||||
|
else:
|
||||||
|
key = os.environ.get('OPENROUTER_API_KEY')
|
||||||
|
if not key:
|
||||||
|
missing.append('OPENROUTER_API_KEY environment variable is not set')
|
||||||
|
|
||||||
|
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
|
||||||
|
|
||||||
|
class ProviderFactory:
|
||||||
|
"""Factory for creating provider validation strategies."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(provider: str, args: Optional[Any] = None) -> Optional[ProviderStrategy]:
|
||||||
|
"""Create a provider validation strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider name
|
||||||
|
args: Optional command line arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Provider validation strategy or None if provider not found
|
||||||
|
"""
|
||||||
|
strategies = {
|
||||||
|
'openai': OpenAIStrategy(),
|
||||||
|
'openai-compatible': OpenAICompatibleStrategy(),
|
||||||
|
'anthropic': AnthropicStrategy(),
|
||||||
|
'openrouter': OpenRouterStrategy()
|
||||||
|
}
|
||||||
|
strategy = strategies.get(provider)
|
||||||
|
return strategy
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""Test package for RA.Aid."""
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
"""Unit tests for environment validation."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
from ra_aid.env import validate_environment
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockArgs:
|
||||||
|
"""Mock arguments for testing."""
|
||||||
|
research_only: bool
|
||||||
|
provider: str
|
||||||
|
expert_provider: str = None
|
||||||
|
|
||||||
|
TEST_CASES = [
|
||||||
|
pytest.param(
|
||||||
|
"research_only_no_model",
|
||||||
|
MockArgs(research_only=True, provider="openai"),
|
||||||
|
(False, [], False, ["TAVILY_API_KEY environment variable is not set"]),
|
||||||
|
{},
|
||||||
|
id="research_only_no_model"
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"research_only_with_model",
|
||||||
|
MockArgs(research_only=True, provider="openai"),
|
||||||
|
(False, [], True, []),
|
||||||
|
{"TAVILY_API_KEY": "test_key"},
|
||||||
|
id="research_only_with_model"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_name,args,expected,env_vars", TEST_CASES)
|
||||||
|
def test_validate_environment_research_only(
|
||||||
|
test_name: str,
|
||||||
|
args: MockArgs,
|
||||||
|
expected: tuple,
|
||||||
|
env_vars: dict,
|
||||||
|
monkeypatch
|
||||||
|
):
|
||||||
|
"""Test validate_environment with research_only flag."""
|
||||||
|
# Clear any existing environment variables
|
||||||
|
monkeypatch.delenv("TAVILY_API_KEY", raising=False)
|
||||||
|
|
||||||
|
# Set test environment variables
|
||||||
|
for key, value in env_vars.items():
|
||||||
|
monkeypatch.setenv(key, value)
|
||||||
|
|
||||||
|
result = validate_environment(args)
|
||||||
|
assert result == expected, f"Failed test case: {test_name}"
|
||||||
|
|
@ -13,11 +13,11 @@ from ra_aid.tools.agent import request_research, request_implementation, request
|
||||||
# Read-only tools that don't modify system state
|
# Read-only tools that don't modify system state
|
||||||
def get_read_only_tools(human_interaction: bool = False, web_research_enabled: bool = False) -> list:
|
def get_read_only_tools(human_interaction: bool = False, web_research_enabled: bool = False) -> list:
|
||||||
"""Get the list of read-only tools, optionally including human interaction tools.
|
"""Get the list of read-only tools, optionally including human interaction tools.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
human_interaction: Whether to include human interaction tools
|
human_interaction: Whether to include human interaction tools
|
||||||
web_research_enabled: Whether to include web research tools
|
web_research_enabled: Whether to include web research tools
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of tool functions
|
List of tool functions
|
||||||
"""
|
"""
|
||||||
|
|
@ -37,10 +37,10 @@ def get_read_only_tools(human_interaction: bool = False, web_research_enabled: b
|
||||||
|
|
||||||
if web_research_enabled:
|
if web_research_enabled:
|
||||||
tools.append(request_web_research)
|
tools.append(request_web_research)
|
||||||
|
|
||||||
if human_interaction:
|
if human_interaction:
|
||||||
tools.append(ask_human)
|
tools.append(ask_human)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
# Define constant tool groups
|
# Define constant tool groups
|
||||||
|
|
@ -58,7 +58,7 @@ RESEARCH_TOOLS = [
|
||||||
|
|
||||||
def get_research_tools(research_only: bool = False, expert_enabled: bool = True, human_interaction: bool = False, web_research_enabled: bool = False) -> list:
|
def get_research_tools(research_only: bool = False, expert_enabled: bool = True, human_interaction: bool = False, web_research_enabled: bool = False) -> list:
|
||||||
"""Get the list of research tools based on mode and whether expert is enabled.
|
"""Get the list of research tools based on mode and whether expert is enabled.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
research_only: Whether to exclude modification tools
|
research_only: Whether to exclude modification tools
|
||||||
expert_enabled: Whether to include expert tools
|
expert_enabled: Whether to include expert tools
|
||||||
|
|
@ -67,33 +67,33 @@ def get_research_tools(research_only: bool = False, expert_enabled: bool = True,
|
||||||
"""
|
"""
|
||||||
# Start with read-only tools
|
# Start with read-only tools
|
||||||
tools = get_read_only_tools(human_interaction, web_research_enabled).copy()
|
tools = get_read_only_tools(human_interaction, web_research_enabled).copy()
|
||||||
|
|
||||||
tools.extend(RESEARCH_TOOLS)
|
tools.extend(RESEARCH_TOOLS)
|
||||||
|
|
||||||
# Add modification tools if not research_only
|
# Add modification tools if not research_only
|
||||||
if not research_only:
|
if not research_only:
|
||||||
tools.extend(MODIFICATION_TOOLS)
|
tools.extend(MODIFICATION_TOOLS)
|
||||||
tools.append(request_implementation)
|
tools.append(request_implementation)
|
||||||
|
|
||||||
# Add expert tools if enabled
|
# Add expert tools if enabled
|
||||||
if expert_enabled:
|
if expert_enabled:
|
||||||
tools.extend(EXPERT_TOOLS)
|
tools.extend(EXPERT_TOOLS)
|
||||||
|
|
||||||
# Add chat-specific tools
|
# Add chat-specific tools
|
||||||
tools.append(request_research)
|
tools.append(request_research)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
def get_planning_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
|
def get_planning_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
|
||||||
"""Get the list of planning tools based on whether expert is enabled.
|
"""Get the list of planning tools based on whether expert is enabled.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
expert_enabled: Whether to include expert tools
|
expert_enabled: Whether to include expert tools
|
||||||
web_research_enabled: Whether to include web research tools
|
web_research_enabled: Whether to include web research tools
|
||||||
"""
|
"""
|
||||||
# Start with read-only tools
|
# Start with read-only tools
|
||||||
tools = get_read_only_tools(web_research_enabled=web_research_enabled).copy()
|
tools = get_read_only_tools(web_research_enabled=web_research_enabled).copy()
|
||||||
|
|
||||||
# Add planning-specific tools
|
# Add planning-specific tools
|
||||||
planning_tools = [
|
planning_tools = [
|
||||||
emit_plan,
|
emit_plan,
|
||||||
|
|
@ -101,41 +101,43 @@ def get_planning_tools(expert_enabled: bool = True, web_research_enabled: bool =
|
||||||
plan_implementation_completed
|
plan_implementation_completed
|
||||||
]
|
]
|
||||||
tools.extend(planning_tools)
|
tools.extend(planning_tools)
|
||||||
|
|
||||||
# Add expert tools if enabled
|
# Add expert tools if enabled
|
||||||
if expert_enabled:
|
if expert_enabled:
|
||||||
tools.extend(EXPERT_TOOLS)
|
tools.extend(EXPERT_TOOLS)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
def get_implementation_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
|
def get_implementation_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
|
||||||
"""Get the list of implementation tools based on whether expert is enabled.
|
"""Get the list of implementation tools based on whether expert is enabled.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
expert_enabled: Whether to include expert tools
|
expert_enabled: Whether to include expert tools
|
||||||
web_research_enabled: Whether to include web research tools
|
web_research_enabled: Whether to include web research tools
|
||||||
"""
|
"""
|
||||||
# Start with read-only tools
|
# Start with read-only tools
|
||||||
tools = get_read_only_tools(web_research_enabled=web_research_enabled).copy()
|
tools = get_read_only_tools(web_research_enabled=web_research_enabled).copy()
|
||||||
|
|
||||||
# Add modification tools since it's not research-only
|
# Add modification tools since it's not research-only
|
||||||
tools.extend(MODIFICATION_TOOLS)
|
tools.extend(MODIFICATION_TOOLS)
|
||||||
tools.extend([
|
tools.extend([
|
||||||
task_completed
|
task_completed
|
||||||
])
|
])
|
||||||
|
|
||||||
# Add expert tools if enabled
|
# Add expert tools if enabled
|
||||||
if expert_enabled:
|
if expert_enabled:
|
||||||
tools.extend(EXPERT_TOOLS)
|
tools.extend(EXPERT_TOOLS)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
def get_web_research_tools(expert_enabled: bool = True) -> list:
|
def get_web_research_tools(expert_enabled: bool = True) -> list:
|
||||||
"""Get the list of tools available for web research.
|
"""Get the list of tools available for web research.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
expert_enabled: Whether expert tools should be included
|
expert_enabled: Whether expert tools should be included
|
||||||
|
human_interaction: Whether to include human interaction tools
|
||||||
|
web_research_enabled: Whether to include web research tools
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: List of tools configured for web research
|
list: List of tools configured for web research
|
||||||
"""
|
"""
|
||||||
|
|
@ -153,10 +155,10 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
|
||||||
|
|
||||||
def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
|
def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list:
|
||||||
"""Get the list of tools available in chat mode.
|
"""Get the list of tools available in chat mode.
|
||||||
|
|
||||||
Chat mode includes research and implementation capabilities but excludes
|
Chat mode includes research and implementation capabilities but excludes
|
||||||
complex planning tools. Human interaction is always enabled.
|
complex planning tools. Human interaction is always enabled.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
expert_enabled: Whether to include expert tools
|
expert_enabled: Whether to include expert tools
|
||||||
web_research_enabled: Whether to include web research tools
|
web_research_enabled: Whether to include web research tools
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
"""Tests for default provider and model configuration."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ra_aid.env import validate_environment
|
||||||
|
from ra_aid.__main__ import parse_arguments
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockArgs:
|
||||||
|
"""Mock arguments for testing."""
|
||||||
|
provider: str
|
||||||
|
expert_provider: Optional[str] = None
|
||||||
|
model: Optional[str] = None
|
||||||
|
expert_model: Optional[str] = None
|
||||||
|
message: Optional[str] = None
|
||||||
|
research_only: bool = False
|
||||||
|
chat: bool = False
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def clean_env(monkeypatch):
|
||||||
|
"""Remove all provider-related environment variables."""
|
||||||
|
env_vars = [
|
||||||
|
"ANTHROPIC_API_KEY",
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
"OPENROUTER_API_KEY",
|
||||||
|
"OPENAI_API_BASE",
|
||||||
|
"EXPERT_ANTHROPIC_API_KEY",
|
||||||
|
"EXPERT_OPENAI_API_KEY",
|
||||||
|
"EXPERT_OPENAI_API_BASE",
|
||||||
|
"TAVILY_API_KEY",
|
||||||
|
"ANTHROPIC_MODEL",
|
||||||
|
]
|
||||||
|
for var in env_vars:
|
||||||
|
monkeypatch.delenv(var, raising=False)
|
||||||
|
yield
|
||||||
|
|
||||||
|
def test_default_anthropic_provider(clean_env, monkeypatch):
|
||||||
|
"""Test that Anthropic is the default provider when no environment variables are set."""
|
||||||
|
args = parse_arguments(["-m", "test message"])
|
||||||
|
assert args.provider == "anthropic"
|
||||||
|
assert args.model == "claude-3-5-sonnet-20241022"
|
||||||
|
|
||||||
|
|
||||||
|
"""Unit tests for provider and model validation in research-only mode."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from argparse import Namespace
|
||||||
|
from ra_aid.env import validate_environment
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockArgs:
|
||||||
|
"""Mock command line arguments."""
|
||||||
|
research_only: bool = False
|
||||||
|
provider: str = None
|
||||||
|
model: str = None
|
||||||
|
expert_provider: str = None
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASES = [
|
||||||
|
pytest.param(
|
||||||
|
"research_only_no_provider",
|
||||||
|
MockArgs(research_only=True),
|
||||||
|
{},
|
||||||
|
"No provider specified",
|
||||||
|
id="research_only_no_provider"
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"research_only_anthropic",
|
||||||
|
MockArgs(research_only=True, provider="anthropic"),
|
||||||
|
{},
|
||||||
|
None,
|
||||||
|
id="research_only_anthropic"
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"research_only_non_anthropic_no_model",
|
||||||
|
MockArgs(research_only=True, provider="openai"),
|
||||||
|
{},
|
||||||
|
"Model is required for non-Anthropic providers",
|
||||||
|
id="research_only_non_anthropic_no_model"
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"research_only_non_anthropic_with_model",
|
||||||
|
MockArgs(research_only=True, provider="openai", model="gpt-4"),
|
||||||
|
{},
|
||||||
|
None,
|
||||||
|
id="research_only_non_anthropic_with_model"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_name,args,env_vars,expected_error", TEST_CASES)
|
||||||
|
def test_research_only_provider_validation(
|
||||||
|
test_name: str,
|
||||||
|
args: MockArgs,
|
||||||
|
env_vars: dict,
|
||||||
|
expected_error: str,
|
||||||
|
monkeypatch
|
||||||
|
):
|
||||||
|
"""Test provider and model validation in research-only mode."""
|
||||||
|
# Set test environment variables
|
||||||
|
for key, value in env_vars.items():
|
||||||
|
monkeypatch.setenv(key, value)
|
||||||
|
|
||||||
|
if expected_error:
|
||||||
|
with pytest.raises(SystemExit, match=expected_error):
|
||||||
|
validate_environment(args)
|
||||||
|
else:
|
||||||
|
validate_environment(args)
|
||||||
|
|
@ -18,33 +18,32 @@ def clean_env(monkeypatch):
|
||||||
env_vars = [
|
env_vars = [
|
||||||
'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'OPENROUTER_API_KEY',
|
'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'OPENROUTER_API_KEY',
|
||||||
'OPENAI_API_BASE', 'EXPERT_ANTHROPIC_API_KEY', 'EXPERT_OPENAI_API_KEY',
|
'OPENAI_API_BASE', 'EXPERT_ANTHROPIC_API_KEY', 'EXPERT_OPENAI_API_KEY',
|
||||||
'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'TAVILY_API_KEY'
|
'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'TAVILY_API_KEY', 'ANTHROPIC_MODEL'
|
||||||
]
|
]
|
||||||
for var in env_vars:
|
for var in env_vars:
|
||||||
monkeypatch.delenv(var, raising=False)
|
monkeypatch.delenv(var, raising=False)
|
||||||
|
|
||||||
def test_anthropic_validation(clean_env, monkeypatch):
|
def test_anthropic_validation(clean_env, monkeypatch):
|
||||||
args = MockArgs(provider="anthropic", expert_provider="openai")
|
args = MockArgs(provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307")
|
||||||
|
|
||||||
# Should fail without API key
|
# Should fail without API key
|
||||||
with pytest.raises(SystemExit):
|
with pytest.raises(SystemExit):
|
||||||
validate_environment(args)
|
validate_environment(args)
|
||||||
|
|
||||||
# Should pass with API key
|
# Should pass with API key and model
|
||||||
monkeypatch.setenv('ANTHROPIC_API_KEY', 'test-key')
|
monkeypatch.setenv('ANTHROPIC_API_KEY', 'test-key')
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
assert not expert_enabled
|
assert not expert_enabled
|
||||||
assert 'EXPERT_OPENAI_API_KEY environment variable is not set' in expert_missing
|
|
||||||
assert not web_research_enabled
|
assert not web_research_enabled
|
||||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||||
|
|
||||||
def test_openai_validation(clean_env, monkeypatch):
|
def test_openai_validation(clean_env, monkeypatch):
|
||||||
args = MockArgs(provider="openai", expert_provider="openai")
|
args = MockArgs(provider="openai", expert_provider="openai")
|
||||||
|
|
||||||
# Should fail without API key
|
# Should fail without API key
|
||||||
with pytest.raises(SystemExit):
|
with pytest.raises(SystemExit):
|
||||||
validate_environment(args)
|
validate_environment(args)
|
||||||
|
|
||||||
# Should pass with API key and enable expert mode with fallback
|
# Should pass with API key and enable expert mode with fallback
|
||||||
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
|
|
@ -56,16 +55,16 @@ def test_openai_validation(clean_env, monkeypatch):
|
||||||
|
|
||||||
def test_openai_compatible_validation(clean_env, monkeypatch):
|
def test_openai_compatible_validation(clean_env, monkeypatch):
|
||||||
args = MockArgs(provider="openai-compatible", expert_provider="openai-compatible")
|
args = MockArgs(provider="openai-compatible", expert_provider="openai-compatible")
|
||||||
|
|
||||||
# Should fail without API key and base URL
|
# Should fail without API key and base URL
|
||||||
with pytest.raises(SystemExit):
|
with pytest.raises(SystemExit):
|
||||||
validate_environment(args)
|
validate_environment(args)
|
||||||
|
|
||||||
# Should fail with only API key
|
# Should fail with only API key
|
||||||
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
||||||
with pytest.raises(SystemExit):
|
with pytest.raises(SystemExit):
|
||||||
validate_environment(args)
|
validate_environment(args)
|
||||||
|
|
||||||
# Should pass with both API key and base URL
|
# Should pass with both API key and base URL
|
||||||
monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
|
monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
|
|
@ -78,10 +77,10 @@ def test_openai_compatible_validation(clean_env, monkeypatch):
|
||||||
|
|
||||||
def test_expert_fallback(clean_env, monkeypatch):
|
def test_expert_fallback(clean_env, monkeypatch):
|
||||||
args = MockArgs(provider="openai", expert_provider="openai")
|
args = MockArgs(provider="openai", expert_provider="openai")
|
||||||
|
|
||||||
# Set only base API key
|
# Set only base API key
|
||||||
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
||||||
|
|
||||||
# Should enable expert mode with fallback
|
# Should enable expert mode with fallback
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
assert expert_enabled
|
assert expert_enabled
|
||||||
|
|
@ -89,7 +88,7 @@ def test_expert_fallback(clean_env, monkeypatch):
|
||||||
assert not web_research_enabled
|
assert not web_research_enabled
|
||||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||||
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key'
|
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key'
|
||||||
|
|
||||||
# Should use explicit expert key if available
|
# Should use explicit expert key if available
|
||||||
monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key')
|
monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key')
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
|
|
@ -101,25 +100,25 @@ def test_expert_fallback(clean_env, monkeypatch):
|
||||||
|
|
||||||
def test_cross_provider_fallback(clean_env, monkeypatch):
|
def test_cross_provider_fallback(clean_env, monkeypatch):
|
||||||
"""Test that fallback works even when providers differ"""
|
"""Test that fallback works even when providers differ"""
|
||||||
args = MockArgs(provider="openai", expert_provider="anthropic")
|
args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
|
||||||
|
|
||||||
# Set base API key for main provider and expert provider
|
# Set base API key for main provider and expert provider
|
||||||
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
||||||
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
|
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
|
||||||
|
monkeypatch.setenv('ANTHROPIC_MODEL', 'claude-3-haiku-20240307')
|
||||||
|
|
||||||
# Should enable expert mode with fallback to ANTHROPIC base key
|
# Should enable expert mode with fallback to ANTHROPIC base key
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
assert expert_enabled
|
assert expert_enabled
|
||||||
assert not expert_missing
|
assert not expert_missing
|
||||||
assert not web_research_enabled
|
assert not web_research_enabled
|
||||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||||
assert os.environ.get('EXPERT_ANTHROPIC_API_KEY') == 'anthropic-key'
|
|
||||||
|
|
||||||
# Try with openai-compatible expert provider
|
# Try with openai-compatible expert provider
|
||||||
args = MockArgs(provider="anthropic", expert_provider="openai-compatible")
|
args = MockArgs(provider="anthropic", expert_provider="openai-compatible")
|
||||||
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
||||||
monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
|
monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
|
||||||
|
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
assert expert_enabled
|
assert expert_enabled
|
||||||
assert not expert_missing
|
assert not expert_missing
|
||||||
|
|
@ -131,43 +130,51 @@ def test_cross_provider_fallback(clean_env, monkeypatch):
|
||||||
def test_no_warning_on_fallback(clean_env, monkeypatch):
|
def test_no_warning_on_fallback(clean_env, monkeypatch):
|
||||||
"""Test that no warning is issued when fallback succeeds"""
|
"""Test that no warning is issued when fallback succeeds"""
|
||||||
args = MockArgs(provider="openai", expert_provider="openai")
|
args = MockArgs(provider="openai", expert_provider="openai")
|
||||||
|
|
||||||
# Set only base API key
|
# Set only base API key
|
||||||
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
monkeypatch.setenv('OPENAI_API_KEY', 'test-key')
|
||||||
|
|
||||||
# Should enable expert mode with fallback and no warnings
|
# Should enable expert mode with fallback and no warnings
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
assert expert_enabled
|
assert expert_enabled
|
||||||
assert not expert_missing # List should be empty
|
assert not expert_missing
|
||||||
assert not web_research_enabled
|
assert not web_research_enabled
|
||||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||||
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key'
|
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key'
|
||||||
|
|
||||||
|
# Should use explicit expert key if available
|
||||||
|
monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key')
|
||||||
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
|
assert expert_enabled
|
||||||
|
assert not expert_missing
|
||||||
|
assert not web_research_enabled
|
||||||
|
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||||
|
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'expert-key'
|
||||||
|
|
||||||
def test_different_providers_no_expert_key(clean_env, monkeypatch):
|
def test_different_providers_no_expert_key(clean_env, monkeypatch):
|
||||||
"""Test behavior when providers differ and only base keys are available"""
|
"""Test behavior when providers differ and only base keys are available"""
|
||||||
args = MockArgs(provider="anthropic", expert_provider="openai")
|
args = MockArgs(provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307")
|
||||||
|
|
||||||
# Set only base keys
|
# Set only base keys
|
||||||
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
|
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
|
||||||
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
||||||
|
|
||||||
# Should enable expert mode and use base OPENAI key
|
# Should enable expert mode and use base OPENAI key
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
assert expert_enabled
|
assert expert_enabled
|
||||||
assert not expert_missing
|
assert not expert_missing
|
||||||
assert not web_research_enabled
|
assert not web_research_enabled
|
||||||
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing
|
||||||
assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'openai-key'
|
|
||||||
|
|
||||||
def test_mixed_provider_openai_compatible(clean_env, monkeypatch):
|
def test_mixed_provider_openai_compatible(clean_env, monkeypatch):
|
||||||
"""Test behavior with openai-compatible expert and different main provider"""
|
"""Test behavior with openai-compatible expert and different main provider"""
|
||||||
args = MockArgs(provider="anthropic", expert_provider="openai-compatible")
|
args = MockArgs(provider="anthropic", expert_provider="openai-compatible", model="claude-3-haiku-20240307")
|
||||||
|
|
||||||
# Set all required keys and URLs
|
# Set all required keys and URLs
|
||||||
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
|
monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key')
|
||||||
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
monkeypatch.setenv('OPENAI_API_KEY', 'openai-key')
|
||||||
monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
|
monkeypatch.setenv('OPENAI_API_BASE', 'http://test')
|
||||||
|
|
||||||
# Should enable expert mode and use base openai key and URL
|
# Should enable expert mode and use base openai key and URL
|
||||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args)
|
||||||
assert expert_enabled
|
assert expert_enabled
|
||||||
|
|
|
||||||
|
|
@ -207,14 +207,13 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
|
||||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Remove fallback
|
monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Remove fallback
|
||||||
monkeypatch.delenv("TAVILY_API_KEY", raising=False) # Remove web research
|
monkeypatch.delenv("TAVILY_API_KEY", raising=False) # Remove web research
|
||||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") # Add for provider validation
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") # Add for provider validation
|
||||||
|
monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") # Add model for provider validation
|
||||||
args = Args(provider="anthropic", expert_provider="openai") # Change base provider to avoid validation error
|
args = Args(provider="anthropic", expert_provider="openai") # Change base provider to avoid validation error
|
||||||
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
|
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
|
||||||
assert not expert_enabled
|
assert not expert_enabled
|
||||||
assert len(expert_missing) == 1
|
assert expert_missing
|
||||||
assert expert_missing[0] == "EXPERT_OPENAI_API_KEY environment variable is not set"
|
|
||||||
assert not web_enabled
|
assert not web_enabled
|
||||||
assert len(web_missing) == 1
|
assert web_missing
|
||||||
assert web_missing[0] == "TAVILY_API_KEY environment variable is not set"
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_anthropic():
|
def mock_anthropic():
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,259 @@
|
||||||
|
"""Integration tests for provider validation and environment handling."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ra_aid.env import validate_environment
|
||||||
|
from ra_aid.provider_strategy import (
|
||||||
|
ProviderFactory,
|
||||||
|
ValidationResult,
|
||||||
|
AnthropicStrategy,
|
||||||
|
OpenAIStrategy,
|
||||||
|
OpenAICompatibleStrategy,
|
||||||
|
OpenRouterStrategy
|
||||||
|
)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockArgs:
|
||||||
|
"""Mock arguments for testing."""
|
||||||
|
provider: str
|
||||||
|
expert_provider: Optional[str] = None
|
||||||
|
model: Optional[str] = None
|
||||||
|
expert_model: Optional[str] = None
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def clean_env():
|
||||||
|
"""Remove all provider-related environment variables."""
|
||||||
|
env_vars = [
|
||||||
|
"ANTHROPIC_API_KEY",
|
||||||
|
"OPENAI_API_KEY",
|
||||||
|
"OPENROUTER_API_KEY",
|
||||||
|
"OPENAI_API_BASE",
|
||||||
|
"EXPERT_ANTHROPIC_API_KEY",
|
||||||
|
"EXPERT_OPENAI_API_KEY",
|
||||||
|
"EXPERT_OPENAI_API_BASE",
|
||||||
|
"TAVILY_API_KEY",
|
||||||
|
"ANTHROPIC_MODEL"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Store original values
|
||||||
|
original_values = {}
|
||||||
|
for var in env_vars:
|
||||||
|
original_values[var] = os.environ.get(var)
|
||||||
|
if var in os.environ:
|
||||||
|
del os.environ[var]
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Restore original values
|
||||||
|
for var, value in original_values.items():
|
||||||
|
if value is not None:
|
||||||
|
os.environ[var] = value
|
||||||
|
elif var in os.environ:
|
||||||
|
del os.environ[var]
|
||||||
|
|
||||||
|
def test_provider_validation_respects_cli_args(clean_env):
|
||||||
|
"""Test that provider validation respects CLI args over defaults."""
|
||||||
|
# Set up environment with only OpenAI credentials
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||||
|
|
||||||
|
# Should succeed with OpenAI provider
|
||||||
|
args = MockArgs(provider="openai")
|
||||||
|
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
|
||||||
|
assert not expert_missing
|
||||||
|
|
||||||
|
# Should fail with Anthropic provider even though it's first alphabetically
|
||||||
|
args = MockArgs(provider="anthropic", model="claude-3-haiku-20240307")
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
validate_environment(args)
|
||||||
|
|
||||||
|
def test_expert_provider_fallback(clean_env):
|
||||||
|
"""Test expert provider falls back to main provider keys."""
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||||
|
args = MockArgs(provider="openai", expert_provider="openai")
|
||||||
|
|
||||||
|
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||||
|
assert expert_enabled
|
||||||
|
assert not expert_missing
|
||||||
|
|
||||||
|
def test_openai_compatible_base_url(clean_env):
|
||||||
|
"""Test OpenAI-compatible provider requires base URL."""
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||||
|
args = MockArgs(provider="openai-compatible")
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
validate_environment(args)
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_BASE"] = "http://test"
|
||||||
|
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||||
|
assert not expert_missing
|
||||||
|
|
||||||
|
def test_expert_provider_separate_keys(clean_env):
|
||||||
|
"""Test expert provider can use separate keys."""
|
||||||
|
os.environ["OPENAI_API_KEY"] = "main-key"
|
||||||
|
os.environ["EXPERT_OPENAI_API_KEY"] = "expert-key"
|
||||||
|
|
||||||
|
args = MockArgs(provider="openai", expert_provider="openai")
|
||||||
|
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||||
|
assert expert_enabled
|
||||||
|
assert not expert_missing
|
||||||
|
|
||||||
|
def test_web_research_independent(clean_env):
|
||||||
|
"""Test web research validation is independent of provider."""
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||||
|
args = MockArgs(provider="openai")
|
||||||
|
|
||||||
|
# Without Tavily key
|
||||||
|
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
|
||||||
|
assert not web_enabled
|
||||||
|
assert web_missing
|
||||||
|
|
||||||
|
# With Tavily key
|
||||||
|
os.environ["TAVILY_API_KEY"] = "test-key"
|
||||||
|
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
|
||||||
|
assert web_enabled
|
||||||
|
assert not web_missing
|
||||||
|
|
||||||
|
def test_provider_factory_unknown_provider(clean_env):
|
||||||
|
"""Test provider factory handles unknown providers."""
|
||||||
|
strategy = ProviderFactory.create("unknown")
|
||||||
|
assert strategy is None
|
||||||
|
|
||||||
|
args = MockArgs(provider="unknown")
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
validate_environment(args)
|
||||||
|
|
||||||
|
def test_provider_strategy_validation(clean_env):
|
||||||
|
"""Test individual provider strategies."""
|
||||||
|
# Test Anthropic strategy
|
||||||
|
strategy = AnthropicStrategy()
|
||||||
|
result = strategy.validate()
|
||||||
|
assert not result.valid
|
||||||
|
assert "ANTHROPIC_API_KEY environment variable is not set" in result.missing_vars
|
||||||
|
|
||||||
|
os.environ["ANTHROPIC_API_KEY"] = "test-key"
|
||||||
|
os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
|
||||||
|
result = strategy.validate()
|
||||||
|
assert result.valid
|
||||||
|
assert not result.missing_vars
|
||||||
|
|
||||||
|
# Test OpenAI strategy
|
||||||
|
strategy = OpenAIStrategy()
|
||||||
|
result = strategy.validate()
|
||||||
|
assert not result.valid
|
||||||
|
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||||
|
result = strategy.validate()
|
||||||
|
assert result.valid
|
||||||
|
assert not result.missing_vars
|
||||||
|
|
||||||
|
def test_missing_provider_arg():
|
||||||
|
"""Test handling of missing provider argument."""
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
validate_environment(None)
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
validate_environment(MockArgs(provider=None))
|
||||||
|
|
||||||
|
def test_empty_provider_arg():
|
||||||
|
"""Test handling of empty provider argument."""
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
validate_environment(MockArgs(provider=""))
|
||||||
|
|
||||||
|
def test_incomplete_openai_compatible_config(clean_env):
|
||||||
|
"""Test OpenAI-compatible provider with incomplete configuration."""
|
||||||
|
strategy = OpenAICompatibleStrategy()
|
||||||
|
|
||||||
|
# No configuration
|
||||||
|
result = strategy.validate()
|
||||||
|
assert not result.valid
|
||||||
|
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
|
||||||
|
assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars
|
||||||
|
|
||||||
|
# Only API key
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||||
|
result = strategy.validate()
|
||||||
|
assert not result.valid
|
||||||
|
assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars
|
||||||
|
|
||||||
|
# Only base URL
|
||||||
|
os.environ.pop("OPENAI_API_KEY")
|
||||||
|
os.environ["OPENAI_API_BASE"] = "http://test"
|
||||||
|
result = strategy.validate()
|
||||||
|
assert not result.valid
|
||||||
|
assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
|
||||||
|
|
||||||
|
def test_incomplete_expert_config(clean_env):
|
||||||
|
"""Test expert provider with incomplete configuration."""
|
||||||
|
# Set main provider but not expert
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||||
|
args = MockArgs(provider="openai", expert_provider="openai-compatible")
|
||||||
|
|
||||||
|
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||||
|
assert not expert_enabled
|
||||||
|
assert len(expert_missing) == 1
|
||||||
|
assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]
|
||||||
|
|
||||||
|
# Set expert key but not base URL
|
||||||
|
os.environ["EXPERT_OPENAI_API_KEY"] = "test-key"
|
||||||
|
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||||
|
assert not expert_enabled
|
||||||
|
assert len(expert_missing) == 1
|
||||||
|
assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]
|
||||||
|
|
||||||
|
def test_empty_environment_variables(clean_env):
|
||||||
|
"""Test handling of empty environment variables."""
|
||||||
|
# Empty API key
|
||||||
|
os.environ["OPENAI_API_KEY"] = ""
|
||||||
|
args = MockArgs(provider="openai")
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
validate_environment(args)
|
||||||
|
|
||||||
|
# Empty base URL
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||||
|
os.environ["OPENAI_API_BASE"] = ""
|
||||||
|
args = MockArgs(provider="openai-compatible")
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
validate_environment(args)
|
||||||
|
|
||||||
|
def test_openrouter_validation(clean_env):
|
||||||
|
"""Test OpenRouter provider validation."""
|
||||||
|
strategy = OpenRouterStrategy()
|
||||||
|
|
||||||
|
# No API key
|
||||||
|
result = strategy.validate()
|
||||||
|
assert not result.valid
|
||||||
|
assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars
|
||||||
|
|
||||||
|
# Empty API key
|
||||||
|
os.environ["OPENROUTER_API_KEY"] = ""
|
||||||
|
result = strategy.validate()
|
||||||
|
assert not result.valid
|
||||||
|
assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars
|
||||||
|
|
||||||
|
# Valid API key
|
||||||
|
os.environ["OPENROUTER_API_KEY"] = "test-key"
|
||||||
|
result = strategy.validate()
|
||||||
|
assert result.valid
|
||||||
|
assert not result.missing_vars
|
||||||
|
|
||||||
|
def test_multiple_expert_providers(clean_env):
|
||||||
|
"""Test validation with multiple expert providers."""
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||||
|
os.environ["ANTHROPIC_API_KEY"] = "test-key"
|
||||||
|
os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"
|
||||||
|
|
||||||
|
# First expert provider valid, second invalid
|
||||||
|
args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
|
||||||
|
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||||
|
assert expert_enabled
|
||||||
|
assert not expert_missing
|
||||||
|
|
||||||
|
# Switch to invalid provider
|
||||||
|
args = MockArgs(provider="openai", expert_provider="openai-compatible")
|
||||||
|
expert_enabled, expert_missing, _, _ = validate_environment(args)
|
||||||
|
assert not expert_enabled
|
||||||
|
assert expert_missing
|
||||||
Loading…
Reference in New Issue