Add Provider/Model Override Arguments for Separate Research/Planner Configurations (#53)
* feat: add research and planner provider/model options to enhance configurability for research and planning tasks
  refactor: create get_effective_model_config function to streamline provider/model resolution logic
  test: add unit tests for effective model configuration and environment validation for research and planner providers
* refactor(agent_utils.py): remove get_effective_model_config function to simplify code and improve readability
  style(agent_utils.py): format debug log statements for better readability
  fix(agent_utils.py): update run_agent functions to directly use config without effective model config
  feat(agent_utils.py): enhance logging for command execution in programmer.py
  test(tests): remove tests related to get_effective_model_config function as it has been removed
* chore(tests): remove outdated tests for research and planner agent configurations to clean up the test suite and improve maintainability
* style(tests): apply consistent formatting and spacing in test_provider_integration.py for improved readability and maintainability
This commit is contained in:
parent 54fdebfc3a
commit 6c4acfea8b
@@ -166,6 +166,10 @@ ra-aid -m "Add new feature" --verbose
 - `--research-only`: Only perform research without implementation
 - `--provider`: The LLM provider to use (choices: anthropic, openai, openrouter, openai-compatible, gemini)
 - `--model`: The model name to use (required for non-Anthropic providers)
+- `--research-provider`: Provider to use specifically for research tasks (falls back to --provider if not specified)
+- `--research-model`: Model to use specifically for research tasks (falls back to --model if not specified)
+- `--planner-provider`: Provider to use specifically for planning tasks (falls back to --provider if not specified)
+- `--planner-model`: Model to use specifically for planning tasks (falls back to --model if not specified)
 - `--cowboy-mode`: Skip interactive approval for shell commands
 - `--expert-provider`: The LLM provider to use for expert knowledge queries (choices: anthropic, openai, openrouter, openai-compatible, gemini)
 - `--expert-model`: The model name to use for expert knowledge queries (required for non-OpenAI providers)

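A single run can now research with one provider and plan with another. The invocation below mirrors the README's existing example style; the provider/model values are borrowed from the test suite later in this commit, not recommendations:

    ra-aid -m "Add new feature" \
        --research-provider anthropic --research-model claude-3-haiku-20240307 \
        --planner-provider openai --planner-model gpt-4
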
@@ -7,11 +7,7 @@ from rich.console import Console
 from langgraph.checkpoint.memory import MemorySaver
 from ra_aid.config import DEFAULT_RECURSION_LIMIT
 from ra_aid.env import validate_environment
-from ra_aid.project_info import (
-    get_project_info,
-    format_project_info,
-    display_project_status,
-)
+from ra_aid.project_info import get_project_info, format_project_info
 from ra_aid.tools.memory import _global_memory
 from ra_aid.tools.human import ask_human
 from ra_aid import print_stage_header, print_error

@@ -81,6 +77,26 @@ Examples:
         help="The LLM provider to use",
     )
     parser.add_argument("--model", type=str, help="The model name to use")
+    parser.add_argument(
+        "--research-provider",
+        type=str,
+        choices=VALID_PROVIDERS,
+        help="Provider to use specifically for research tasks",
+    )
+    parser.add_argument(
+        "--research-model",
+        type=str,
+        help="Model to use specifically for research tasks",
+    )
+    parser.add_argument(
+        "--planner-provider",
+        type=str,
+        choices=VALID_PROVIDERS,
+        help="Provider to use specifically for planning tasks",
+    )
+    parser.add_argument(
+        "--planner-model", type=str, help="Model to use specifically for planning tasks"
+    )
     parser.add_argument(
         "--cowboy-mode",
         action="store_true",

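None of the new options declares an argparse default, so unset flags parse to None, which is exactly what lets the `or` fallback in main() (shown further down) take effect. A self-contained sketch of that interaction; the base default here is illustrative, not necessarily ra-aid's actual default:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--provider", default="openai")  # stand-in base default
    parser.add_argument("--research-provider")           # unset -> None
    parser.add_argument("--planner-provider")            # unset -> None

    args = parser.parse_args(["--planner-provider", "anthropic"])
    print(args.planner_provider or args.provider)   # anthropic (override wins)
    print(args.research_provider or args.provider)  # openai (falls back)
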
@@ -130,9 +146,7 @@ Examples:
         help="Maximum recursion depth for agent operations (default: 100)",
     )
     parser.add_argument(
-        '--aider-config',
-        type=str,
-        help='Specify the aider config file path'
+        "--aider-config", type=str, help="Specify the aider config file path"
     )

     if args is None:

@@ -223,7 +237,7 @@ def main():
     if expert_missing:
         console.print(
             Panel(
-                f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n"
+                "[yellow]Expert tools disabled due to missing configuration:[/yellow]\n"
                 + "\n".join(f"- {m}" for m in expert_missing)
                 + "\nSet the required environment variables or args to enable expert mode.",
                 title="Expert Tools Disabled",

@@ -234,7 +248,7 @@ def main():
     if web_research_missing:
         console.print(
             Panel(
-                f"[yellow]Web research disabled due to missing configuration:[/yellow]\n"
+                "[yellow]Web research disabled due to missing configuration:[/yellow]\n"
                 + "\n".join(f"- {m}" for m in web_research_missing)
                 + "\nSet the required environment variables to enable web research.",
                 title="Web Research Disabled",

@@ -242,11 +256,12 @@ def main():
         )
     )

-    # Create the base model after validation
-    model = initialize_llm(args.provider, args.model, temperature=args.temperature)
-
     # Handle chat mode
     if args.chat:
+        # Initialize chat model with default provider/model
+        chat_model = initialize_llm(
+            args.provider, args.model, temperature=args.temperature
+        )
         if args.research_only:
             print_error("Chat mode cannot be used with --research-only")
             sys.exit(1)

@@ -291,7 +306,7 @@ def main():

         # Create chat agent with appropriate tools
         chat_agent = create_agent(
-            model,
+            chat_model,
             get_chat_tools(
                 expert_enabled=expert_enabled,
                 web_research_enabled=web_research_enabled,

@@ -334,20 +349,39 @@ def main():
     # Store config in global memory for access by is_informational_query
     _global_memory["config"] = config

-    # Store model configuration
+    # Store base provider/model configuration
     _global_memory["config"]["provider"] = args.provider
     _global_memory["config"]["model"] = args.model

-    # Store expert provider and model in config
+    # Store expert provider/model (no fallback)
     _global_memory["config"]["expert_provider"] = args.expert_provider
     _global_memory["config"]["expert_model"] = args.expert_model

+    # Store planner config with fallback to base values
+    _global_memory["config"]["planner_provider"] = (
+        args.planner_provider or args.provider
+    )
+    _global_memory["config"]["planner_model"] = args.planner_model or args.model
+
+    # Store research config with fallback to base values
+    _global_memory["config"]["research_provider"] = (
+        args.research_provider or args.provider
+    )
+    _global_memory["config"]["research_model"] = args.research_model or args.model
+
     # Run research stage
     print_stage_header("Research Stage")

+    # Initialize research model with potential overrides
+    research_provider = args.research_provider or args.provider
+    research_model_name = args.research_model or args.model
+    research_model = initialize_llm(
+        research_provider, research_model_name, temperature=args.temperature
+    )
+
     run_research_agent(
         base_task,
-        model,
+        research_model,
         expert_enabled=expert_enabled,
         research_only=args.research_only,
         hil=args.hil,

@@ -357,10 +391,17 @@ def main():

     # Proceed with planning and implementation if not an informational query
     if not is_informational_query():
+        # Initialize planning model with potential overrides
+        planner_provider = args.planner_provider or args.provider
+        planner_model_name = args.planner_model or args.model
+        planning_model = initialize_llm(
+            planner_provider, planner_model_name, temperature=args.temperature
+        )
+
         # Run planning agent
         run_planning_agent(
             base_task,
-            model,
+            planning_model,
             expert_enabled=expert_enabled,
             hil=args.hil,
             memory=planning_memory,

@@ -142,12 +142,18 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
         model_info = get_model_info(provider_model)
         max_input_tokens = model_info.get("max_input_tokens")
         if max_input_tokens:
-            logger.debug(f"Using litellm token limit for {model_name}: {max_input_tokens}")
+            logger.debug(
+                f"Using litellm token limit for {model_name}: {max_input_tokens}"
+            )
             return max_input_tokens
     except litellm.exceptions.NotFoundError:
-        logger.debug(f"Model {model_name} not found in litellm, falling back to models_tokens")
+        logger.debug(
+            f"Model {model_name} not found in litellm, falling back to models_tokens"
+        )
     except Exception as e:
-        logger.debug(f"Error getting model info from litellm: {e}, falling back to models_tokens")
+        logger.debug(
+            f"Error getting model info from litellm: {e}, falling back to models_tokens"
+        )

     # Fallback to models_tokens dict
     # Normalize model name for fallback lookup (e.g. claude-2 -> claude2)

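The lookup order in get_model_token_limit is: ask litellm first, then fall back to the bundled models_tokens table under a normalized model name. A condensed, runnable sketch of that flow, not the function verbatim; the table entry and the normalization rule are inferred from the comment and tests in this commit:

    from litellm import get_model_info

    # Stand-in for ra-aid's models_tokens fallback table (entry illustrative).
    models_tokens = {"anthropic": {"claude2": 100000}}

    def token_limit_sketch(provider: str, model_name: str):
        try:
            # litellm metadata lookup; NotFoundError or any other error falls through.
            info = get_model_info(f"{provider}/{model_name}")
            if info.get("max_input_tokens"):
                return info["max_input_tokens"]
        except Exception:
            pass
        normalized = model_name.replace("-", "")  # e.g. "claude-2" -> "claude2"
        return models_tokens.get(provider, {}).get(normalized)
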
@@ -155,7 +161,9 @@ def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
     provider_tokens = models_tokens.get(provider, {})
     max_input_tokens = provider_tokens.get(normalized_name, None)
     if max_input_tokens:
-        logger.debug(f"Found token limit for {provider}/{model_name}: {max_input_tokens}")
+        logger.debug(
+            f"Found token limit for {provider}/{model_name}: {max_input_tokens}"
+        )
     else:
         logger.debug(f"Could not find token limit for {provider}/{model_name}")

@@ -360,7 +368,10 @@ def run_research_agent(

     config = _global_memory.get("config", {}) if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }
     if config:
         run_config.update(config)

@@ -467,8 +478,12 @@ def run_web_research_agent(
     )

     config = _global_memory.get("config", {}) if not config else config
+
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }
     if config:
         run_config.update(config)

@@ -551,8 +566,12 @@ def run_planning_agent(
     )

     config = _global_memory.get("config", {}) if not config else config
+
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }
     if config:
         run_config.update(config)

@@ -642,7 +661,10 @@ def run_task_implementation_agent(

     config = _global_memory.get("config", {}) if not config else config
     recursion_limit = config.get("recursion_limit", DEFAULT_RECURSION_LIMIT)
-    run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": recursion_limit}
+    run_config = {
+        "configurable": {"thread_id": thread_id},
+        "recursion_limit": recursion_limit,
+    }
     if config:
         run_config.update(config)

@@ -5,7 +5,9 @@ from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 from langchain_google_genai import ChatGoogleGenerativeAI
 from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
+from ra_aid.logging_config import get_logger

+logger = get_logger(__name__)

 def get_env_var(name: str, expert: bool = False) -> Optional[str]:
     """Get environment variable with optional expert prefix and fallback."""

@@ -122,7 +124,18 @@ def create_llm_client(
     if not config:
         raise ValueError(f"Unsupported provider: {provider}")

-    # Only pass temperature if it's explicitly set and not in expert mode
+    logger.debug(
+        "Creating LLM client with provider=%s, model=%s, temperature=%s, expert=%s",
+        provider,
+        model_name,
+        temperature,
+        is_expert
+    )
+
+    # Handle temperature for expert mode
+    if is_expert:
+        temperature = 0
+
     temp_kwargs = {}
     if not is_expert and temperature is not None:
         temp_kwargs = {"temperature": temperature}

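Per the hunk above, expert clients pin temperature to 0 but never forward it through temp_kwargs, since the `not is_expert` guard excludes them; only non-expert clients with an explicit temperature pass it as a keyword argument. How the pinned 0 is consumed further down is outside this hunk. An isolated sketch of just that branch logic:

    def temperature_kwargs_sketch(is_expert: bool, temperature):
        if is_expert:
            temperature = 0  # expert mode pins temperature to 0
        temp_kwargs = {}
        if not is_expert and temperature is not None:
            temp_kwargs = {"temperature": temperature}  # only non-expert forwards it
        return temperature, temp_kwargs

    # temperature_kwargs_sketch(True, 0.7)  -> (0, {})
    # temperature_kwargs_sketch(False, 0.7) -> (0.7, {"temperature": 0.7})
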
@@ -1,5 +1,6 @@
 import os
 from typing import List, Dict, Union
+from ra_aid.logging_config import get_logger
 from ra_aid.tools.memory import _global_memory
 from langchain_core.tools import tool
 from rich.console import Console

@@ -10,6 +11,7 @@ from ra_aid.proc.interactive import run_interactive_command
 from ra_aid.text.processing import truncate_output

 console = Console()
+logger = get_logger(__name__)

 @tool
 def run_programming_task(instructions: str, files: List[str] = []) -> Dict[str, Union[str, int, bool]]:

@@ -81,6 +83,7 @@ Returns: { "output": stdout+stderr, "return_code": 0 if success, "success": True

     markdown_content = "".join(task_display)
     console.print(Panel(Markdown(markdown_content), title="🤖 Aider Task", border_style="bright_blue"))
+    logger.debug(f"command: {command}")

     try:
         # Run the command interactively

@@ -60,41 +60,43 @@ def test_get_model_token_limit_missing_config(mock_memory):
     assert token_limit is None

+
 def test_get_model_token_limit_litellm_success():
     """Test get_model_token_limit successfully getting limit from litellm."""
     config = {"provider": "anthropic", "model": "claude-2"}

-    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.return_value = {"max_input_tokens": 100000}
         token_limit = get_model_token_limit(config)
         assert token_limit == 100000

+
 def test_get_model_token_limit_litellm_not_found():
     """Test fallback to models_tokens when litellm raises NotFoundError."""
     config = {"provider": "anthropic", "model": "claude-2"}

-    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.side_effect = litellm.exceptions.NotFoundError(
-            message="Model not found",
-            model="claude-2",
-            llm_provider="anthropic"
+            message="Model not found", model="claude-2", llm_provider="anthropic"
         )
         token_limit = get_model_token_limit(config)
         assert token_limit == models_tokens["anthropic"]["claude2"]

+
 def test_get_model_token_limit_litellm_error():
     """Test fallback to models_tokens when litellm raises other exceptions."""
     config = {"provider": "anthropic", "model": "claude-2"}

-    with patch('ra_aid.agent_utils.get_model_info') as mock_get_info:
+    with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
         mock_get_info.side_effect = Exception("Unknown error")
         token_limit = get_model_token_limit(config)
         assert token_limit == models_tokens["anthropic"]["claude2"]

+
 def test_get_model_token_limit_unexpected_error():
     """Test returning None when unexpected errors occur."""
     config = None  # This will cause an attribute error when accessed

     token_limit = get_model_token_limit(config)
     assert token_limit is None

@@ -11,6 +11,10 @@ class MockArgs:
     expert_provider: str
     model: Optional[str] = None
     expert_model: Optional[str] = None
+    research_provider: Optional[str] = None
+    research_model: Optional[str] = None
+    planner_provider: Optional[str] = None
+    planner_model: Optional[str] = None

 @pytest.fixture
 def clean_env(monkeypatch):

@@ -166,6 +170,7 @@ def test_different_providers_no_expert_key(clean_env, monkeypatch):
     assert not web_research_enabled
     assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing

+
 def test_mixed_provider_openai_compatible(clean_env, monkeypatch):
     """Test behavior with openai-compatible expert and different main provider"""
     args = MockArgs(provider="anthropic", expert_provider="openai-compatible", model="claude-3-haiku-20240307")

@@ -121,3 +121,45 @@ def test_missing_message():
     # Verify message is captured when provided
     args = parse_arguments(["-m", "test"])
     assert args.message == "test"
+
+
+def test_research_model_provider_args(mock_dependencies):
+    """Test that research-specific model/provider args are correctly stored in config."""
+    from ra_aid.__main__ import main
+    import sys
+    from unittest.mock import patch
+
+    _global_memory.clear()
+
+    with patch.object(sys, 'argv', [
+        'ra-aid', '-m', 'test message',
+        '--research-provider', 'anthropic',
+        '--research-model', 'claude-3-haiku-20240307',
+        '--planner-provider', 'openai',
+        '--planner-model', 'gpt-4'
+    ]):
+        main()
+        config = _global_memory["config"]
+        assert config["research_provider"] == "anthropic"
+        assert config["research_model"] == "claude-3-haiku-20240307"
+        assert config["planner_provider"] == "openai"
+        assert config["planner_model"] == "gpt-4"
+
+
+def test_planner_model_provider_args(mock_dependencies):
+    """Test that planner provider/model args fall back to main config when not specified."""
+    from ra_aid.__main__ import main
+    import sys
+    from unittest.mock import patch
+
+    _global_memory.clear()
+
+    with patch.object(sys, 'argv', [
+        'ra-aid', '-m', 'test message',
+        '--provider', 'openai',
+        '--model', 'gpt-4'
+    ]):
+        main()
+        config = _global_memory["config"]
+        assert config["planner_provider"] == "openai"
+        assert config["planner_model"] == "gpt-4"

@@ -13,16 +13,23 @@ from ra_aid.provider_strategy import (
     OpenAIStrategy,
     OpenAICompatibleStrategy,
     OpenRouterStrategy,
-    GeminiStrategy
+    GeminiStrategy,
 )

+
 @dataclass
 class MockArgs:
     """Mock arguments for testing."""
+
     provider: str
     expert_provider: Optional[str] = None
     model: Optional[str] = None
     expert_model: Optional[str] = None
+    research_provider: Optional[str] = None
+    research_model: Optional[str] = None
+    planner_provider: Optional[str] = None
+    planner_model: Optional[str] = None

+
 @pytest.fixture
 def clean_env():

@@ -39,18 +46,18 @@ def clean_env():
         "ANTHROPIC_MODEL",
         "GEMINI_API_KEY",
         "EXPERT_GEMINI_API_KEY",
-        "GEMINI_MODEL"
+        "GEMINI_MODEL",
     ]

     # Store original values
     original_values = {}
     for var in env_vars:
         original_values[var] = os.environ.get(var)
         if var in os.environ:
             del os.environ[var]

     yield

     # Restore original values
     for var, value in original_values.items():
         if value is not None:

@@ -58,6 +65,7 @@ def clean_env():
         elif var in os.environ:
             del os.environ[var]

+
 def test_provider_validation_respects_cli_args(clean_env):
     """Test that provider validation respects CLI args over defaults."""
     # Set up environment with only OpenAI credentials

@@ -65,7 +73,9 @@ def test_provider_validation_respects_cli_args(clean_env):

     # Should succeed with OpenAI provider
     args = MockArgs(provider="openai")
-    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
+    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
+        args
+    )
     assert not expert_missing

     # Should fail with Anthropic provider even though it's first alphabetically

@@ -73,62 +83,72 @@ def test_provider_validation_respects_cli_args(clean_env):
     with pytest.raises(SystemExit):
         validate_environment(args)

+
 def test_expert_provider_fallback(clean_env):
     """Test expert provider falls back to main provider keys."""
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai", expert_provider="openai")

     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert expert_enabled
     assert not expert_missing

+
 def test_openai_compatible_base_url(clean_env):
     """Test OpenAI-compatible provider requires base URL."""
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai-compatible")

     with pytest.raises(SystemExit):
         validate_environment(args)

     os.environ["OPENAI_API_BASE"] = "http://test"
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert not expert_missing

+
 def test_expert_provider_separate_keys(clean_env):
     """Test expert provider can use separate keys."""
     os.environ["OPENAI_API_KEY"] = "main-key"
     os.environ["EXPERT_OPENAI_API_KEY"] = "expert-key"

     args = MockArgs(provider="openai", expert_provider="openai")
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert expert_enabled
     assert not expert_missing

+
 def test_web_research_independent(clean_env):
     """Test web research validation is independent of provider."""
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai")

     # Without Tavily key
-    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
+    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
+        args
+    )
     assert not web_enabled
     assert web_missing

     # With Tavily key
     os.environ["TAVILY_API_KEY"] = "test-key"
-    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
+    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(
+        args
+    )
     assert web_enabled
     assert not web_missing

+
 def test_provider_factory_unknown_provider(clean_env):
     """Test provider factory handles unknown providers."""
     strategy = ProviderFactory.create("unknown")
     assert strategy is None

     args = MockArgs(provider="unknown")
     with pytest.raises(SystemExit):
         validate_environment(args)

+
 def test_provider_strategy_validation(clean_env):
     """Test individual provider strategies."""
     # Test Anthropic strategy

@@ -154,35 +174,38 @@ def test_provider_strategy_validation(clean_env):
     assert result.valid
     assert not result.missing_vars

+
 def test_missing_provider_arg():
     """Test handling of missing provider argument."""
     with pytest.raises(SystemExit):
         validate_environment(None)

     with pytest.raises(SystemExit):
         validate_environment(MockArgs(provider=None))

+
 def test_empty_provider_arg():
     """Test handling of empty provider argument."""
     with pytest.raises(SystemExit):
         validate_environment(MockArgs(provider=""))

+
 def test_incomplete_openai_compatible_config(clean_env):
     """Test OpenAI-compatible provider with incomplete configuration."""
     strategy = OpenAICompatibleStrategy()

     # No configuration
     result = strategy.validate()
     assert not result.valid
     assert "OPENAI_API_KEY environment variable is not set" in result.missing_vars
     assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars

     # Only API key
     os.environ["OPENAI_API_KEY"] = "test-key"
     result = strategy.validate()
     assert not result.valid
     assert "OPENAI_API_BASE environment variable is not set" in result.missing_vars

     # Only base URL
     os.environ.pop("OPENAI_API_KEY")
     os.environ["OPENAI_API_BASE"] = "http://test"

@@ -194,29 +217,30 @@ def test_incomplete_openai_compatible_config(clean_env):
 def test_incomplete_gemini_config(clean_env):
     """Test Gemini provider with incomplete configuration."""
     strategy = GeminiStrategy()

     # No configuration
     result = strategy.validate()
     assert not result.valid
     assert "GEMINI_API_KEY environment variable is not set" in result.missing_vars

     # Valid API key
     os.environ["GEMINI_API_KEY"] = "test-key"
     result = strategy.validate()
     assert result.valid
     assert not result.missing_vars

+
 def test_incomplete_expert_config(clean_env):
     """Test expert provider with incomplete configuration."""
     # Set main provider but not expert
     os.environ["OPENAI_API_KEY"] = "test-key"
     args = MockArgs(provider="openai", expert_provider="openai-compatible")

     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert not expert_enabled
     assert len(expert_missing) == 1
     assert "EXPERT_OPENAI_API_BASE" in expert_missing[0]

     # Set expert key but not base URL
     os.environ["EXPERT_OPENAI_API_KEY"] = "test-key"
     expert_enabled, expert_missing, _, _ = validate_environment(args)

@@ -232,7 +256,7 @@ def test_empty_environment_variables(clean_env):
     args = MockArgs(provider="openai")
     with pytest.raises(SystemExit):
         validate_environment(args)

     # Empty base URL
     os.environ["OPENAI_API_KEY"] = "test-key"
     os.environ["OPENAI_API_BASE"] = ""

@@ -240,10 +264,11 @@ def test_empty_environment_variables(clean_env):
     with pytest.raises(SystemExit):
         validate_environment(args)

+
 def test_openrouter_validation(clean_env):
     """Test OpenRouter provider validation."""
     strategy = OpenRouterStrategy()

     # No API key
     result = strategy.validate()
     assert not result.valid

@@ -254,13 +279,14 @@ def test_openrouter_validation(clean_env):
     result = strategy.validate()
     assert not result.valid
     assert "OPENROUTER_API_KEY environment variable is not set" in result.missing_vars

     # Valid API key
     os.environ["OPENROUTER_API_KEY"] = "test-key"
     result = strategy.validate()
     assert result.valid
     assert not result.missing_vars

+
 def test_multiple_expert_providers(clean_env):
     """Test validation with multiple expert providers."""
     os.environ["OPENAI_API_KEY"] = "test-key"

@@ -268,14 +294,17 @@ def test_multiple_expert_providers(clean_env):
     os.environ["ANTHROPIC_MODEL"] = "claude-3-haiku-20240307"

     # First expert provider valid, second invalid
-    args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307")
+    args = MockArgs(
+        provider="openai",
+        expert_provider="anthropic",
+        expert_model="claude-3-haiku-20240307",
+    )
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert expert_enabled
     assert not expert_missing

     # Switch to invalid provider
     args = MockArgs(provider="openai", expert_provider="openai-compatible")
     expert_enabled, expert_missing, _, _ = validate_environment(args)
     assert not expert_enabled
     assert expert_missing