Added Provider/Model Override Arguments to Separate Research/Planner Configurations (#53)

* feat: add research and planner provider/model options to enhance configurability for research and planning tasks
refactor: create get_effective_model_config function to streamline provider/model resolution logic
test: add unit tests for effective model configuration and environment validation for research and planner providers

* refactor(agent_utils.py): remove get_effective_model_config function to simplify code and improve readability
style(agent_utils.py): format debug log statements for better readability
fix(agent_utils.py): update run_agent functions to directly use config without effective model config
feat(agent_utils.py): enhance logging for command execution in programmer.py
test(tests): remove tests related to get_effective_model_config function as it has been removed

* chore(tests): remove outdated tests for research and planner agent configurations to clean up the test suite and improve maintainability

* style(tests): apply consistent formatting and spacing in test_provider_integration.py for improved readability and maintainability
Ariel Frischer 2025-01-24 15:00:47 -08:00 committed by GitHub
parent 54fdebfc3a
commit 6c4acfea8b
10 changed files with 3870 additions and 66 deletions

View File

@@ -166,6 +166,10 @@ ra-aid -m "Add new feature" --verbose
 - `--research-only`: Only perform research without implementation
 - `--provider`: The LLM provider to use (choices: anthropic, openai, openrouter, openai-compatible, gemini)
 - `--model`: The model name to use (required for non-Anthropic providers)
+- `--research-provider`: Provider to use specifically for research tasks (falls back to --provider if not specified)
+- `--research-model`: Model to use specifically for research tasks (falls back to --model if not specified)
+- `--planner-provider`: Provider to use specifically for planning tasks (falls back to --provider if not specified)
+- `--planner-model`: Model to use specifically for planning tasks (falls back to --model if not specified)
 - `--cowboy-mode`: Skip interactive approval for shell commands
 - `--expert-provider`: The LLM provider to use for expert knowledge queries (choices: anthropic, openai, openrouter, openai-compatible, gemini)
 - `--expert-model`: The model name to use for expert knowledge queries (required for non-OpenAI providers)
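
For example, a run can keep one provider for planning and implementation while routing research elsewhere; the flags are the ones added above, and the model names are illustrative values taken from this change's tests:

ra-aid -m "Add new feature" --provider openai --model gpt-4 --research-provider anthropic --research-model claude-3-haiku-20240307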

View File

@@ -7,11 +7,7 @@ from rich.console import Console
 from langgraph.checkpoint.memory import MemorySaver
 from ra_aid.config import DEFAULT_RECURSION_LIMIT
 from ra_aid.env import validate_environment
-from ra_aid.project_info import (
-    get_project_info,
-    format_project_info,
-    display_project_status,
-)
+from ra_aid.project_info import get_project_info, format_project_info
 from ra_aid.tools.memory import _global_memory
 from ra_aid.tools.human import ask_human
 from ra_aid import print_stage_header, print_error
@@ -81,6 +77,26 @@ Examples:
         help="The LLM provider to use",
     )
     parser.add_argument("--model", type=str, help="The model name to use")
+    parser.add_argument(
+        "--research-provider",
+        type=str,
+        choices=VALID_PROVIDERS,
+        help="Provider to use specifically for research tasks",
+    )
+    parser.add_argument(
+        "--research-model",
+        type=str,
+        help="Model to use specifically for research tasks",
+    )
+    parser.add_argument(
+        "--planner-provider",
+        type=str,
+        choices=VALID_PROVIDERS,
+        help="Provider to use specifically for planning tasks",
+    )
+    parser.add_argument(
+        "--planner-model", type=str, help="Model to use specifically for planning tasks"
+    )
     parser.add_argument(
         "--cowboy-mode",
         action="store_true",
@@ -130,9 +146,7 @@ Examples:
         help="Maximum recursion depth for agent operations (default: 100)",
     )
     parser.add_argument(
-        '--aider-config',
-        type=str,
-        help='Specify the aider config file path'
+        "--aider-config", type=str, help="Specify the aider config file path"
     )

     if args is None:
@@ -223,7 +237,7 @@ def main():
     if expert_missing:
         console.print(
             Panel(
-                f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n"
+                "[yellow]Expert tools disabled due to missing configuration:[/yellow]\n"
                 + "\n".join(f"- {m}" for m in expert_missing)
                 + "\nSet the required environment variables or args to enable expert mode.",
                 title="Expert Tools Disabled",
@@ -234,7 +248,7 @@ def main():
     if web_research_missing:
         console.print(
             Panel(
-                f"[yellow]Web research disabled due to missing configuration:[/yellow]\n"
+                "[yellow]Web research disabled due to missing configuration:[/yellow]\n"
                 + "\n".join(f"- {m}" for m in web_research_missing)
                 + "\nSet the required environment variables to enable web research.",
                 title="Web Research Disabled",
@@ -242,11 +256,12 @@ def main():
             )
         )

-    # Create the base model after validation
-    model = initialize_llm(args.provider, args.model, temperature=args.temperature)
-
     # Handle chat mode
     if args.chat:
+        # Initialize chat model with default provider/model
+        chat_model = initialize_llm(
+            args.provider, args.model, temperature=args.temperature
+        )
         if args.research_only:
             print_error("Chat mode cannot be used with --research-only")
             sys.exit(1)
@@ -291,7 +306,7 @@ def main():

         # Create chat agent with appropriate tools
         chat_agent = create_agent(
-            model,
+            chat_model,
             get_chat_tools(
                 expert_enabled=expert_enabled,
                 web_research_enabled=web_research_enabled,
@@ -334,20 +349,39 @@ def main():
     # Store config in global memory for access by is_informational_query
     _global_memory["config"] = config

-    # Store model configuration
+    # Store base provider/model configuration
     _global_memory["config"]["provider"] = args.provider
     _global_memory["config"]["model"] = args.model

-    # Store expert provider and model in config
+    # Store expert provider/model (no fallback)
     _global_memory["config"]["expert_provider"] = args.expert_provider
     _global_memory["config"]["expert_model"] = args.expert_model

+    # Store planner config with fallback to base values
+    _global_memory["config"]["planner_provider"] = (
+        args.planner_provider or args.provider
+    )
+    _global_memory["config"]["planner_model"] = args.planner_model or args.model
+
+    # Store research config with fallback to base values
+    _global_memory["config"]["research_provider"] = (
+        args.research_provider or args.provider
+    )
+    _global_memory["config"]["research_model"] = args.research_model or args.model
+
     # Run research stage
     print_stage_header("Research Stage")

+    # Initialize research model with potential overrides
+    research_provider = args.research_provider or args.provider
+    research_model_name = args.research_model or args.model
+    research_model = initialize_llm(
+        research_provider, research_model_name, temperature=args.temperature
+    )
+
     run_research_agent(
         base_task,
-        model,
+        research_model,
         expert_enabled=expert_enabled,
         research_only=args.research_only,
         hil=args.hil,
@@ -357,10 +391,17 @@ def main():

     # Proceed with planning and implementation if not an informational query
     if not is_informational_query():
+        # Initialize planning model with potential overrides
+        planner_provider = args.planner_provider or args.provider
+        planner_model_name = args.planner_model or args.model
+        planning_model = initialize_llm(
+            planner_provider, planner_model_name, temperature=args.temperature
+        )
+
         # Run planning agent
         run_planning_agent(
             base_task,
-            model,
+            planning_model,
             expert_enabled=expert_enabled,
             hil=args.hil,
             memory=planning_memory,
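
A minimal sketch of the fallback rule these hunks apply (the stage-specific flag wins, otherwise the base --provider/--model is used), assuming a simplified args object rather than the real argparse namespace; model names are reused from the tests in this change:

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Args:
    provider: str = "openai"
    model: Optional[str] = "gpt-4"
    research_provider: Optional[str] = None
    research_model: Optional[str] = None
    planner_provider: Optional[str] = "anthropic"
    planner_model: Optional[str] = "claude-3-haiku-20240307"


def resolve_stage(stage_provider: Optional[str], stage_model: Optional[str], args: Args) -> Tuple[str, Optional[str]]:
    # Same "or" fallback used in main(): prefer the stage override, else the base value.
    return stage_provider or args.provider, stage_model or args.model


args = Args()
print(resolve_stage(args.research_provider, args.research_model, args))  # ('openai', 'gpt-4')
print(resolve_stage(args.planner_provider, args.planner_model, args))    # ('anthropic', 'claude-3-haiku-20240307')

The resolved pair is what each stage passes to initialize_llm, so a planner override never leaks into the research agent or the chat agent.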

View File

@@ -142,12 +142,18 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
         model_info = get_model_info(provider_model)
         max_input_tokens = model_info.get("max_input_tokens")
         if max_input_tokens:
-            logger.debug(f"Using litellm token limit for {model_name}: {max_input_tokens}")
+            logger.debug(
+                f"Using litellm token limit for {model_name}: {max_input_tokens}"
+            )
             return max_input_tokens
     except litellm.exceptions.NotFoundError:
-        logger.debug(f"Model {model_name} not found in litellm, falling back to models_tokens")
+        logger.debug(
+            f"Model {model_name} not found in litellm, falling back to models_tokens"
+        )
     except Exception as e:
-        logger.debug(f"Error getting model info from litellm: {e}, falling back to models_tokens")
+        logger.debug(
+            f"Error getting model info from litellm: {e}, falling back to models_tokens"
+        )

     # Fallback to models_tokens dict
     # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)
@@ -155,7 +161,9 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
     provider_tokens = models_tokens.get(provider, {})
     max_input_tokens = provider_tokens.get(normalized_name, None)
     if max_input_tokens:
-        logger.debug(f"Found token limit for {provider}/{model_name}: {max_input_tokens}")
+        logger.debug(
+            f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
+        )
     else:
         logger.debug(f"Could not find token limit for {provider}/{model_name}")

@@ -360,7 +368,10 @@ def run_research_agent(
     config = _global_memory.get("config", {}) if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)

-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }

     if config:
         run_config.update(config)
@@ -467,8 +478,12 @@ def run_web_research_agent(
     )

     config = _global_memory.get("config", {}) if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)

-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }

     if config:
         run_config.update(config)
@@ -551,8 +566,12 @@ def run_planning_agent(
     )

     config = _global_memory.get("config", {}) if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)

-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }

     if config:
         run_config.update(config)
@@ -642,7 +661,10 @@ def run_task_implementation_agent(
     config = _global_memory.get("config", {}) if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)

-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }

     if config:
         run_config.update(config)
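
The same run_config construction recurs in each agent runner above; as a standalone sketch of that pattern (the helper name is hypothetical, and the default of 100 matches the documented --recursion-limit default):

from typing import Any, Dict, Optional

DEFAULT_RECURSION_LIMIT = 100  # documented default for --recursion-limit


def build_run_config(thread_id: str, config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    config = config or {}
    run_config = {
        "configurable": {"thread_id": thread_id},
        "recursion_limit": config.get("recursion_limit", DEFAULT_RECURSION_LIMIT),
    }
    # Any stored config keys (provider overrides, flags, ...) ride along with the run config.
    if config:
        run_config.update(config)
    return run_config


print(build_run_config("thread-1", {"recursion_limit": 50, "research_provider": "openai"}))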

View File

@@ -5,7 +5,9 @@ from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 from langchain_google_genai import ChatGoogleGenerativeAI
 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
+from ra_aid.logging_config import get_logger

+logger = get_logger(__name__)

 def get_env_var(name: str, expert: bool = False) -> Optional[str]:
     """Get environment variable with optional expert prefix and fallback."""
@@ -122,7 +124,18 @@ def create_llm_client(
     if not config:
         raise ValueError(f"Unsupported provider: {provider}")

-    # Only pass temperature if it's explicitly set and not in expert mode
+    logger.debug(
+        "Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s",
+        provider,
+        model_name,
+        temperature,
+        is_expert
+    )
+
+    # Handle temperature for expert mode
+    if is_expert:
+        temperature = 0
+
     temp_kwargs = {}
     if not is_expert and temperature is not None:
         temp_kwargs = {"temperature": temperature}
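
A small sketch of the temperature handling introduced above, pulled out as a standalone function (the helper name is hypothetical): expert clients have temperature pinned to 0, and, exactly as in the hunk, the temperature kwarg is only forwarded for non-expert clients.

from typing import Optional


def temperature_kwargs(temperature: Optional[float], is_expert: bool) -> dict:
    if is_expert:
        temperature = 0  # expert mode forces temperature to 0
    temp_kwargs = {}
    if not is_expert and temperature is not None:
        temp_kwargs = {"temperature": temperature}  # only forwarded outside expert mode
    return temp_kwargs


print(temperature_kwargs(0.7, is_expert=False))  # {'temperature': 0.7}
print(temperature_kwargs(0.7, is_expert=True))   # {}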

View File

@@ -1,5 +1,6 @@
 import os
 from typing import List, Dict, Union
+from ra_aid.logging_config import get_logger
 from ra_aid.tools.memory import _global_memory
 from langchain_core.tools import tool
 from rich.console import Console
@@ -10,6 +11,7 @@ from ra_aid.proc.interactive import run_interactive_command
 from ra_aid.text.processing import truncate_output

 console = Console()
+logger = get_logger(__name__)

 @tool
 def run_programming_task(instructions: str, files: List[str] = []) -> Dict[str, Union[str, int, bool]]:
@@ -81,6 +83,7 @@ Returns: { "output": stdout+stderr, "return_code": 0 if success, "success": True
     markdown_content = "".join(task_display)
     console.print(Panel(Markdown(markdown_content), title="🤖 Aider Task", border_style="bright_blue"))

+    logger.debug(f"command: {command}")
     try:
         # Run the command interactively

View File

@@ -60,41 +60,43 @@ def test_get_model_token_limit_missing_config(mock_memory):
     assert token_limit is None

+
 def test_get_model_token_limit_litellm_success():
     """Test get_model_token_limit successfully getting limit from litellm."""
     config = {"provider": "anthropic", "model": "claude-2"}
-    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.return_value = {"max_input_tokens": 100000}
         token_limit = get_model_token_limit(config)
         assert token_limit == 100000

+
 def test_get_model_token_limit_litellm_not_found():
     """Test fallback to models_tokens when litellm raises NotFoundError."""
     config = {"provider": "anthropic", "model": "claude-2"}
-    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-            message="Model not found",
-            model="claude-2",
-            llm_provider="anthropic"
+            message="Model not found", model="claude-2", llm_provider="anthropic"
         )
         token_limit = get_model_token_limit(config)
         assert token_limit == models_tokens["anthropic"]["claude2"]

+
 def test_get_model_token_limit_litellm_error():
     """Test fallback to models_tokens when litellm raises other exceptions."""
     config = {"provider": "anthropic", "model": "claude-2"}
-    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.side_effect = Exception("Unknown error")
         token_limit = get_model_token_limit(config)
         assert token_limit == models_tokens["anthropic"]["claude2"]

+
 def test_get_model_token_limit_unexpected_error():
     """Test returning None when unexpected errors occur."""
     config = None  # This will cause an attribute error when accessed
     token_limit = get_model_token_limit(config)
     assert token_limit is None

View File

@@ -11,6 +11,10 @@ class MockArgs:
     expert_provider: str
     model: Optional[str] = None
     expert_model: Optional[str] = None
+    research_provider: Optional[str] = None
+    research_model: Optional[str] = None
+    planner_provider: Optional[str] = None
+    planner_model: Optional[str] = None

 @pytest.fixture
 def clean_env(monkeypatch):
@@ -166,6 +170,7 @@ def test_different_providers_no_expert_key(clean_env, monkeypatch):
     assert not web_research_enabled
     assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing

+
 def test_mixed_provider_openai_compatible(clean_env, monkeypatch):
     """Test behavior with openai-compatible expert and different main provider"""
     args = MockArgs(provider="anthropic", expert_provider="openai-compatible", model="claude-3-haiku-20240307")

View File

@@ -121,3 +121,45 @@ def test_missing_message():
     # Verify message is captured when provided
     args = parse_arguments(["-m", "test"])
     assert args.message == "test"
+
+
+def test_research_model_provider_args(mock_dependencies):
+    """Test that research-specific model/provider args are correctly stored in config."""
+    from ra_aid.__main__ import main
+    import sys
+    from unittest.mock import patch
+
+    _global_memory.clear()
+
+    with patch.object(sys, 'argv', [
+        'ra-aid', '-m', 'test message',
+        '--research-provider', 'anthropic',
+        '--research-model', 'claude-3-haiku-20240307',
+        '--planner-provider', 'openai',
+        '--planner-model', 'gpt-4'
+    ]):
+        main()
+
+        config = _global_memory["config"]
+        assert config["research_provider"] == "anthropic"
+        assert config["research_model"] == "claude-3-haiku-20240307"
+        assert config["planner_provider"] == "openai"
+        assert config["planner_model"] == "gpt-4"
+
+
+def test_planner_model_provider_args(mock_dependencies):
+    """Test that planner provider/model args fall back to main config when not specified."""
+    from ra_aid.__main__ import main
+    import sys
+    from unittest.mock import patch
+
+    _global_memory.clear()
+
+    with patch.object(sys, 'argv', [
+        'ra-aid', '-m', 'test message',
+        '--provider', 'openai',
+        '--model', 'gpt-4'
+    ]):
+        main()
+
+        config = _global_memory["config"]
+        assert config["planner_provider"] == "openai"
+        assert config["planner_model"] == "gpt-4"

View File

@@ -13,16 +13,23 @@ from ra_aid.provider_strategy import (
     OpenAIStrategy,
     OpenAICompatibleStrategy,
     OpenRouterStrategy,
-    GeminiStrategy
+    GeminiStrategy,
 )

+
 @dataclass
 class MockArgs:
     """Mock arguments for testing."""
+
     provider: str
     expert_provider: Optional[str] = None
     model: Optional[str] = None
     expert_model: Optional[str] = None
+    research_provider: Optional[str] = None
+    research_model: Optional[str] = None
+    planner_provider: Optional[str] = None
+    planner_model: Optional[str] = None

+
 @pytest.fixture
 def clean_env():
@@ -39,18 +46,18 @@ def clean_env():
         "ANTHROPIC_MODEL",
         "GEMINI_API_KEY",
         "EXPERT_GEMINI_API_KEY",
-        "GEMINI_MODEL"
+        "GEMINI_MODEL",
     ]

     # Store original values
     original_values = {}
     for var in env_vars:
         original_values[var] = os.environ.get(var)
         if var in os.environ:
             del os.environ[var]

     yield

     # Restore original values
     for var, value in original_values.items():
         if value is not None:
@@ -58,6 +65,7 @@ def clean_env():
         elif var in os.environ:
             del os.environ[var]

+
 def test_provider_validation_respects_cli_args(clean_env):
     """Test that provider validation respects CLI args over defaults."""
     # Set up environment with only OpenAI credentials
@@ -65,7 +73,9 @@ def test_provider_validation_respects_cli_args(clean_env):

     # Should succeed with OpenAI provider
     args = MockArgs(provider="openai")
-    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
+    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
+        args
+    )
     assert not expert_missing

     # Should fail with Anthropic provider even though it's first alphabetically
@@ -73,62 +83,72 @@ def test_provider_validation_respects_cli_args(clean_env):
     with pytest.raises(SystemExit):
         validate_environment(args)

+
 def test_expert_provider_fallback(clean_env):
     """Test expert provider falls back to main provider keys."""
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai", expert_provider="openai")
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert expert_enabled
     assert not expert_missing

+
 def test_openai_compatible_base_url(clean_env):
     """Test OpenAI-compatible provider requires base URL."""
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai-compatible")
     with pytest.raises(SystemExit):
         validate_environment(args)

     os.environ["OPENAI_API_BASE"] = "http://test"
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert not expert_missing

+
 def test_expert_provider_separate_keys(clean_env):
     """Test expert provider can use separate keys."""
     os.environ["OPENAI_API_KEY"] = "main-key"
     os.environ["EXPERT_OPENAI_API_KEY"] = "expert-key"
     args = MockArgs(provider="openai", expert_provider="openai")
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert expert_enabled
     assert not expert_missing

+
 def test_web_research_independent(clean_env):
     """Test web research validation is independent of provider."""
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai")

     # Without Tavily key
-    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
+    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
+        args
+    )
     assert not web_enabled
     assert web_missing

     # With Tavily key
     os.environ["TAVILY_API_KEY"] = "test-key"
-    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
+    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
+        args
+    )
     assert web_enabled
     assert not web_missing

+
 def test_provider_factory_unknown_provider(clean_env):
     """Test provider factory handles unknown providers."""
     strategy = ProviderFactory.create("unknown")
     assert strategy is None

     args = MockArgs(provider="unknown")
     with pytest.raises(SystemExit):
         validate_environment(args)

+
 def test_provider_strategy_validation(clean_env):
     """Test individual provider strategies."""
     # Test Anthropic strategy
@@ -154,35 +174,38 @@ def test_provider_strategy_validation(clean_env):
     assert result.valid
     assert not result.missing_vars

+
 def test_missing_provider_arg():
     """Test handling of missing provider argument."""
     with pytest.raises(SystemExit):
         validate_environment(None)

     with pytest.raises(SystemExit):
         validate_environment(MockArgs(provider=None))

+
 def test_empty_provider_arg():
     """Test handling of empty provider argument."""
     with pytest.raises(SystemExit):
         validate_environment(MockArgs(provider=""))

+
 def test_incomplete_openai_compatible_config(clean_env):
     """Test OpenAI-compatible provider with incomplete configuration."""
     strategy = OpenAICompatibleStrategy()

     # No configuration
     result = strategy.validate()
     assert not result.valid
     assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
     assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars

     # Only API key
     os.environ["OPENAI_API_KEY"] = "test-key"
     result = strategy.validate()
     assert not result.valid
     assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars

     # Only base URL
     os.environ.pop("OPENAI_API_KEY")
     os.environ["OPENAI_API_BASE"] = "http://test"
@@ -194,29 +217,30 @@ def test_incomplete_openai_compatible_config(clean_env):
 def test_incomplete_gemini_config(clean_env):
     """Test Gemini provider with incomplete configuration."""
     strategy = GeminiStrategy()

     # No configuration
     result = strategy.validate()
     assert not result.valid
     assert "GEMINI_API_KEY environment variable is not set" in result.missing_vars

     # Valid API key
     os.environ["GEMINI_API_KEY"] = "test-key"
     result = strategy.validate()
     assert result.valid
     assert not result.missing_vars

+
 def test_incomplete_expert_config(clean_env):
     """Test expert provider with incomplete configuration."""
     # Set main provider but not expert
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai", expert_provider="openai-compatible")

     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert not expert_enabled
     assert len(expert_missing) == 1
     assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]

     # Set expert key but not base URL
     os.environ["EXPERT_OPENAI_API_KEY"] = "test-key"
     expert_enabled, expert_missing, _, _ = validate_environment(args)
@@ -232,7 +256,7 @@ def test_empty_environment_variables(clean_env):
     args = MockArgs(provider="openai")
     with pytest.raises(SystemExit):
         validate_environment(args)

     # Empty base URL
     os.environ["OPENAI_API_KEY"] = "test-key"
     os.environ["OPENAI_API_BASE"] = ""
@@ -240,10 +264,11 @@ def test_empty_environment_variables(clean_env):
     with pytest.raises(SystemExit):
         validate_environment(args)

+
 def test_openrouter_validation(clean_env):
     """Test OpenRouter provider validation."""
     strategy = OpenRouterStrategy()

     # No API key
     result = strategy.validate()
     assert not result.valid
@@ -254,13 +279,14 @@ def test_openrouter_validation(clean_env):
     result = strategy.validate()
     assert not result.valid
     assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars

     # Valid API key
     os.environ["OPENROUTER_API_KEY"] = "test-key"
     result = strategy.validate()
     assert result.valid
     assert not result.missing_vars

+
 def test_multiple_expert_providers(clean_env):
     """Test validation with multiple expert providers."""
     os.environ["OPENAI_API_KEY"] = "test-key"
@@ -268,14 +294,17 @@ def test_multiple_expert_providers(clean_env):
     os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"

     # First expert provider valid, second invalid
-    args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
+    args = MockArgs(
+        provider="openai",
+        expert_provider="anthropic",
+        expert_model="claude-3-haiku-20240307",
+    )
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert expert_enabled
     assert not expert_missing

     # Switch to invalid provider
     args = MockArgs(provider="openai", expert_provider="openai-compatible")
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert not expert_enabled
     assert expert_missing

uv.lock (new file, 3643 lines): diff suppressed because it is too large.