RA.Aid/tests/ra_aid/test_default_provider.py

97 lines
2.7 KiB
Python

"""Tests for default provider and model configuration."""
from dataclasses import dataclass
from typing import Optional
import pytest
from ra_aid.__main__ import parse_arguments
from ra_aid.env import validate_environment
@dataclass
class MockArgs:
    """Stand-in for the parsed CLI namespace handed to ``validate_environment``.

    Every attribute defaults to "not supplied": ``None`` for string options
    and ``False`` for boolean flags.
    """

    # Main and expert provider/model selection.
    provider: Optional[str] = None
    expert_provider: Optional[str] = None
    model: Optional[str] = None
    expert_model: Optional[str] = None

    # Task input and run-mode switches.
    message: Optional[str] = None
    research_only: bool = False
    chat: bool = False
@pytest.fixture
def clean_env(monkeypatch):
    """Scrub provider API keys, API bases and model overrides from the environment.

    Deletions go through ``monkeypatch`` so they are reverted after the test;
    variables that are not currently set are silently skipped.
    """
    scrubbed = (
        # Primary provider credentials / endpoints.
        "ANTHROPIC_API_KEY",
        "OPENAI_API_KEY",
        "OPENROUTER_API_KEY",
        "OPENAI_API_BASE",
        # Expert-provider credentials / endpoints.
        "EXPERT_ANTHROPIC_API_KEY",
        "EXPERT_OPENAI_API_KEY",
        "EXPERT_OPENAI_API_BASE",
        # Other tool keys and model overrides.
        "TAVILY_API_KEY",
        "ANTHROPIC_MODEL",
    )
    for name in scrubbed:
        monkeypatch.delenv(name, raising=False)
    yield
def test_default_anthropic_provider(clean_env):
    """Test that Anthropic is the default provider when no environment variables are set.

    Only ``-m`` is passed on the command line, so the provider and model seen
    here come from ``parse_arguments`` defaults. ``clean_env`` has already
    removed the API keys and ``ANTHROPIC_MODEL`` beforehand, so values on the
    host machine cannot influence the result.
    """
    # The unused ``monkeypatch`` parameter was dropped; clean_env already
    # requests it for the environment scrubbing this test relies on.
    args = parse_arguments(["-m", "test message"])

    assert args.provider == "anthropic"
    assert args.model == "claude-3-5-sonnet-20241022"
def _research_case(name, args, expected_error, env_vars=None):
    """Return a ``pytest.param`` whose pytest id is the same string as its test_name.

    Args:
        name: Case name, used both as the ``test_name`` value and as the id.
        args: Stand-in CLI namespace for the case.
        expected_error: Expected exit message, or None when validation should pass.
        env_vars: Environment overrides for the case; a fresh empty dict if omitted.
    """
    return pytest.param(
        name,
        args,
        {} if env_vars is None else env_vars,
        expected_error,
        id=name,
    )


# Each entry is (test_name, args, env_vars, expected_error).
TEST_CASES = [
    # Research-only mode with no provider at all is rejected.
    _research_case(
        "research_only_no_provider",
        MockArgs(research_only=True),
        "No provider specified",
    ),
    # Anthropic is accepted without an explicit model.
    _research_case(
        "research_only_anthropic",
        MockArgs(research_only=True, provider="anthropic"),
        None,
    ),
    # Any other provider must be paired with a model.
    _research_case(
        "research_only_non_anthropic_no_model",
        MockArgs(research_only=True, provider="openai"),
        "Model is required for non-Anthropic providers",
    ),
    _research_case(
        "research_only_non_anthropic_with_model",
        MockArgs(research_only=True, provider="openai", model="gpt-4"),
        None,
    ),
]
@pytest.mark.parametrize("test_name,args,env_vars,expected_error", TEST_CASES)
def test_research_only_provider_validation(
    test_name: str,
    args: MockArgs,
    env_vars: dict,
    expected_error: Optional[str],
    monkeypatch,
    clean_env,
):
    """Test provider and model validation in research-only mode.

    The case runs against a scrubbed environment (``clean_env``), so
    provider-related variables from the host such as ``ANTHROPIC_API_KEY`` do
    not leak in. Only the variables listed in the case's ``env_vars`` are set.

    Args:
        test_name: Human-readable case name (duplicates the pytest id).
        args: Stand-in CLI namespace passed to ``validate_environment``.
        env_vars: Environment variables to set for this case.
        expected_error: Message ``validate_environment`` should exit with, or
            None when validation is expected to succeed.
        monkeypatch: pytest fixture used to set variables with automatic undo.
        clean_env: Fixture that removes provider-related environment variables.
    """
    import re  # local import: only needed to match the message literally

    # Set test environment variables (after clean_env has scrubbed the host's).
    for key, value in env_vars.items():
        monkeypatch.setenv(key, value)

    if expected_error is None:
        # Validation must return normally; any exception fails the test.
        validate_environment(args)
        return

    # pytest.raises(match=...) runs re.search against str(exc); escape the
    # expected message so it is matched as literal text, not as a regex.
    with pytest.raises(SystemExit, match=re.escape(expected_error)):
        validate_environment(args)